diff --git a/README.md b/README.md
index 4a4ce792..d3d92865 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,8 @@
diff --git a/README_ja.md b/README_ja.md
new file mode 100644
index 00000000..073b0c48
--- /dev/null
+++ b/README_ja.md
@@ -0,0 +1,300 @@
+Some representative examples include:
+
+NLP:
+
+* [nlp_gpt3_text-generation_2.7B](https://modelscope.cn/models/damo/nlp_gpt3_text-generation_2.7B)
+* [ChatYuan-large](https://modelscope.cn/models/ClueAI/ChatYuan-large)
+* [mengzi-t5-base](https://modelscope.cn/models/langboat/mengzi-t5-base)
+* [nlp_csanmt_translation_en2zh](https://modelscope.cn/models/damo/nlp_csanmt_translation_en2zh)
+* [nlp_raner_named-entity-recognition_chinese-base-news](https://modelscope.cn/models/damo/nlp_raner_named-entity-recognition_chinese-base-news)
+* [nlp_structbert_word-segmentation_chinese-base](https://modelscope.cn/models/damo/nlp_structbert_word-segmentation_chinese-base)
+* [Erlangshen-RoBERTa-330M-Sentiment](https://modelscope.cn/models/fengshenbang/Erlangshen-RoBERTa-330M-Sentiment)
+* [nlp_convai_text2sql_pretrain_cn](https://modelscope.cn/models/damo/nlp_convai_text2sql_pretrain_cn)
+
+Multi-modal:
+
+* [multi-modal_clip-vit-base-patch16_zh](https://modelscope.cn/models/damo/multi-modal_clip-vit-base-patch16_zh)
+* [ofa_pretrain_base_zh](https://modelscope.cn/models/damo/ofa_pretrain_base_zh)
+* [Taiyi-Stable-Diffusion-1B-Chinese-v0.1](https://modelscope.cn/models/fengshenbang/Taiyi-Stable-Diffusion-1B-Chinese-v0.1)
+* [mplug_visual-question-answering_coco_large_en](https://modelscope.cn/models/damo/mplug_visual-question-answering_coco_large_en)
+
+CV:
+
+* [cv_controlnet_controllable-image-generation_nine-annotators](https://modelscope.cn/models/dienstag/cv_controlnet_controllable-image-generation_nine-annotators/summary)
+* [cv_tinynas_object-detection_damoyolo](https://modelscope.cn/models/damo/cv_tinynas_object-detection_damoyolo)
+* [cv_unet_person-image-cartoon_compound-models](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models)
+* [cv_convnextTiny_ocr-recognition-general_damo](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo)
+* [cv_resnet18_human-detection](https://modelscope.cn/models/damo/cv_resnet18_human-detection)
+* [cv_resnet50_face-detection_retinaface](https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface)
+* [cv_unet_image-matting](https://modelscope.cn/models/damo/cv_unet_image-matting)
+* [cv_F3Net_product-segmentation](https://modelscope.cn/models/damo/cv_F3Net_product-segmentation)
+* [cv_resnest101_general_recognition](https://modelscope.cn/models/damo/cv_resnest101_general_recognition)
+
+Audio:
+
+* [speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
+* [speech_sambert-hifigan_tts_zh-cn_16k](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k)
+* [speech_charctc_kws_phone-xiaoyun](https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun)
+* [u2pp_conformer-asr-cn-16k-online](https://modelscope.cn/models/wenet/u2pp_conformer-asr-cn-16k-online)
+* [speech_fsmn_vad_zh-cn-16k-common-pytorch](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary)
+* [punc_ct-transformer_zh-cn-common-vocab272727-pytorch](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
+* [speech_frcrn_ans_cirm_16k](https://modelscope.cn/models/damo/speech_frcrn_ans_cirm_16k)
+* [speech_dfsmn_aec_psm_16k](https://modelscope.cn/models/damo/speech_dfsmn_aec_psm_16k)
+
+AI for Science:
+
+* [uni-fold-monomer](https://modelscope.cn/models/DPTech/uni-fold-monomer/summary)
+* [uni-fold-multimer](https://modelscope.cn/models/DPTech/uni-fold-multimer/summary)
+
+**Note:** Most models on ModelScope are public and can be downloaded from the modelscope website ([www.modelscope.cn](https://www.modelscope.cn)) without registering an account. To download models with the API provided by the modelscope library or with git, please refer to the instructions in [Model Download](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%B8%8B%E8%BD%BD).
+
+# Quick Tour
+
+Across a wide range of tasks, we provide a unified interface for inference with `pipeline`, and for fine-tuning and evaluation with `Trainer`.
+
+Regardless of the input type (image, text, audio, video...), an inference pipeline can be implemented with just a few lines of code:
+
+```python
+>>> from modelscope.pipelines import pipeline
+>>> word_segmentation = pipeline('word-segmentation', model='damo/nlp_structbert_word-segmentation_chinese-base')
+>>> word_segmentation('今天天气不错,适合出去游玩')
+{'output': '今天 天气 不错 , 适合 出去 游玩'}
+```
+
+Given an image, portrait matting (i.e. background removal) can be accomplished with the following code snippet:
+
+```python
+>>> import cv2
+>>> from modelscope.pipelines import pipeline
+
+>>> portrait_matting = pipeline('portrait-matting')
+>>> result = portrait_matting('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matting.png')
+>>> cv2.imwrite('result.png', result['output_img'])
+```
+
+The output image with the background removed looks like this:
+
+Fine-tuning and evaluation can likewise be done with a few more lines of code to set up the training dataset and the trainer, with the heavy lifting of training and evaluating a model encapsulated in the `trainer.train()` and `trainer.evaluate()` interfaces.
+
+For example, fine-tuning the gpt3 base model (1.3B) on a Chinese poetry dataset yields a model that can be used for Chinese poetry generation.
+
+```python
+>>> from modelscope.metainfo import Trainers
+>>> from modelscope.msdatasets import MsDataset
+>>> from modelscope.trainers import build_trainer
+
+>>> train_dataset = MsDataset.load('chinese-poetry-collection', split='train').remap_columns({'text1': 'src_txt'})
+>>> eval_dataset = MsDataset.load('chinese-poetry-collection', split='test').remap_columns({'text1': 'src_txt'})
+>>> max_epochs = 10
+>>> tmp_dir = './gpt3_poetry'
+
+>>> kwargs = dict(
+        model='damo/nlp_gpt3_text-generation_1.3B',
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        max_epochs=max_epochs,
+        work_dir=tmp_dir)
+
+>>> trainer = build_trainer(name=Trainers.gpt3_trainer, default_args=kwargs)
+>>> trainer.train()
+```
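+
+Evaluation uses the same trainer. The following is a minimal sketch based on the `trainer.evaluate()` interface mentioned above; the exact metric keys in the returned dictionary depend on the task and evaluation configuration:
+
+```python
+>>> eval_results = trainer.evaluate()
+>>> print(eval_results)
+```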
+
+# Why use the ModelScope library
+
+1. A unified and concise user interface is abstracted over different tasks and different models. Model inference and training can be implemented in as few as 3 and 10 lines of code, respectively, making it convenient to explore models from different fields in the ModelScope community. All models integrated into ModelScope are ready to use, which makes it easy to get started with AI in both educational and industrial settings.
+
+2. ModelScope offers a model-centric development and application experience. It streamlines support for model training, inference, export and deployment, and makes it easy for users to build their own MLOps on top of the ModelScope ecosystem.
+
+3. The model inference and training processes use a modular design and provide a rich set of functional module implementations, making it convenient for users to customize their own inference, training and other processes.
+
+4. For distributed model training, especially for large models, rich training strategies are supported, including data parallelism, model parallelism and hybrid parallelism.
+
+# Installation
+
+## Docker
+
+The ModelScope library currently supports popular deep-learning frameworks for model training and inference, including PyTorch, TensorFlow and ONNX. All releases are tested and run on Python 3.7+, PyTorch 1.8+, TensorFlow 1.15 or TensorFlow 2.0+.
+
+To let every model in ModelScope be used out of the box, official docker images are provided with every release. Developers can build on these images and use them directly, skipping all environment installation and configuration. The latest versions of the CPU and GPU images can currently be obtained from:
+
+CPU docker image
+```shell
+# py37
+registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
+
+# py38
+registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
+```
+
+GPU docker image
+```shell
+# py37
+registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
+
+# py38
+registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
+```
+
+## Setting up a local Python environment
+
+You can also set up a local ModelScope environment using pip and conda. We recommend [anaconda](https://docs.anaconda.com/anaconda/install/) for creating the local Python environment:
+
+```shell
+conda create -n modelscope python=3.7
+conda activate modelscope
+```
+
+PyTorch or TensorFlow can be installed separately, according to the requirements of each model.
+* Install PyTorch [doc](https://pytorch.org/get-started/locally/)
+* Install TensorFlow [doc](https://www.tensorflow.org/install/pip)
+
+After installing the required machine-learning framework, install the modelscope library as follows:
+
+If you only want to try downloading models/datasets and play with the modelscope framework, you can install the core modelscope components:
+```shell
+pip install modelscope
+```
+
+If you want to use multi-modal models:
+```shell
+pip install modelscope[multi-modal]
+```
+
+If you want to use NLP models:
+```shell
+pip install modelscope[nlp] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+```
+
+If you want to use CV models:
+```shell
+pip install modelscope[cv] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+```
+
+If you want to use audio models:
+```shell
+pip install modelscope[audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+```
+
+If you want to use science models:
+```shell
+pip install modelscope[science] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+```
+
+`Notes`:
+1. Currently, some audio-task models only run on Linux with python3.7 and tensorflow1.15.4. Most other models can be installed and used on Windows and Mac (x86).
+
+2. In the audio field, some models use the third-party library SoundFile to process wav files. On Linux, you need to manually install libsndfile, on which SoundFile depends ([doc link](https://github.com/bastibe/python-soundfile#installation)). On Windows and MacOS it is installed automatically without user intervention. For example, on Ubuntu it can be installed with the following commands:
+    ```shell
+    sudo apt-get update
+    sudo apt-get install libsndfile1
+    ```
+
+3. Some computer-vision models require mmcv-full; please refer to the mmcv [installation guide](https://github.com/open-mmlab/mmcv#installation). A minimal installation is as follows:
+
+    ```shell
+    pip uninstall mmcv  # if you have installed mmcv, uninstall it
+    pip install -U openmim
+    mim install mmcv-full
+    ```
+
+# Learn More
+
+We provide additional documentation, including:
+* [More detailed installation guide](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85)
+* [Introduction to tasks](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D)
+* [Using pipelines for model inference](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline)
+* [Fine-tuning example](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train)
+* [Data preprocessing](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86)
+* [Evaluation](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0)
+* [Contributing your own model to ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88)
+
+# License
+
+This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
diff --git a/README_zh.md b/README_zh.md
index f5401f33..7cac99fb 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -21,7 +21,8 @@
diff --git a/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py b/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
index 4d62f66f..47af0b90 100644
--- a/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
+++ b/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
@@ -1,15 +1,19 @@
import os
from modelscope.metainfo import Trainers
-from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
- ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
+ ds_dict = ASRDataset.load(
+ params.data_path,
+ namespace='speech_asr',
+ download_mode=params.download_mode)
kwargs = dict(
model=params.model,
data_dir=ds_dict,
@@ -36,5 +40,6 @@ if __name__ == '__main__':
# If dataset_type="large", batch_bins is in milliseconds,
params.max_epoch = 50 # maximum number of training epochs
params.lr = 0.00005 # set the learning rate
+ params.download_mode = DownloadMode.FORCE_REDOWNLOAD # re-download the data; otherwise keep the default DownloadMode.REUSE_DATASET_IF_EXISTS
modelscope_finetune(params)
diff --git a/examples/pytorch/baichuan/finetune_baichuan.py b/examples/pytorch/baichuan/finetune_baichuan.py
index 075ebc31..353f5023 100644
--- a/examples/pytorch/baichuan/finetune_baichuan.py
+++ b/examples/pytorch/baichuan/finetune_baichuan.py
@@ -219,9 +219,7 @@ kwargs = dict(
train_dataset=train_dataset,
eval_dataset=validation_dataset,
seed=args.seed,
- cfg_modify_fn=cfg_modify_fn,
- # No placement for model, leave the model to `device_map`
- device='cpu' if args.device_map else 'gpu')
+ cfg_modify_fn=cfg_modify_fn)
trainer: EpochBasedTrainer = build_trainer(
name=args.trainer, default_args=kwargs)
diff --git a/examples/pytorch/chatglm6b/chatglm_trainer.py b/examples/pytorch/chatglm6b/chatglm_trainer.py
index b34563bd..84167713 100644
--- a/examples/pytorch/chatglm6b/chatglm_trainer.py
+++ b/examples/pytorch/chatglm6b/chatglm_trainer.py
@@ -6,7 +6,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from modelscope import EpochBasedTrainer, get_logger
-logger = get_logger(__name__)
+logger = get_logger()
class Seq2SeqTrainer(EpochBasedTrainer):
@@ -16,6 +16,8 @@ class Seq2SeqTrainer(EpochBasedTrainer):
if ignore_pad_token_for_loss:
tokens = np.where(tokens != -100, tokens,
self.tokenizer.pad_token_id)
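+ # Also map token ids outside the tokenizer vocabulary to the pad token before decoding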
+ tokens = np.where(tokens < self.tokenizer.vocab_size, tokens,
+ self.tokenizer.pad_token_id)
return [
t for t in self.tokenizer.batch_decode(
tokens, skip_special_tokens=True) if t != ''
@@ -59,7 +61,9 @@ class Seq2SeqTrainer(EpochBasedTrainer):
gen_kwargs['input_ids'] = generation_inputs
gen_kwargs['pad_token_id'] = self.tokenizer.pad_token_id
- generated_tokens = self.model.generate(**gen_kwargs)
+ self.model.eval()
+ with torch.no_grad():
+ generated_tokens = self.model.generate(**gen_kwargs)
generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
# in case the batch is shorter than max length, the output should be padded
diff --git a/examples/pytorch/chatglm6b/finetune.py b/examples/pytorch/chatglm6b/finetune.py
index 40eb8720..0e31ce28 100644
--- a/examples/pytorch/chatglm6b/finetune.py
+++ b/examples/pytorch/chatglm6b/finetune.py
@@ -143,6 +143,14 @@ class Chatglm6bArguments(TrainingArgs):
metadata={'help': 'The lora alpha'},
)
+ use_amp: int = field(
+ default=0,
+ metadata={
+ 'help':
+ 'Whether to use amp(automatic mixed precision) to train the model.'
+ },
+ )
+
args = Chatglm6bArguments(eval_metrics='chatglm').parse_cli()
print(args)
@@ -160,6 +168,13 @@ def cfg_modify_fn(cfg):
cfg.merge_from_dict(config)
else:
cfg = config
+ if args.use_amp:
+ if not getattr(cfg.train, 'hooks', None):
+ cfg.train.hooks = []
+ cfg.train.hooks.append({
+ 'type': 'TorchAMPOptimizerHook',
+ # Optional loss_scale parameter here.
+ })
if cfg.train.lr_scheduler.type == 'LinearLR':
cfg.train.lr_scheduler['total_iters'] = \
int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs
@@ -193,13 +208,15 @@ model_config['model'] = ConfigDict({
'type': config['model']['type'],
})
-if config['model']['type'] == 'chatglm6b':
- model_config['model']['pre_seq_len'] = args.pre_seq_len
- model_config['model']['prefix_projection'] = args.prefix_projection
-
+model_config['model']['pre_seq_len'] = args.pre_seq_len
+model_config['model']['prefix_projection'] = args.prefix_projection
tokenizer = ChatGLMTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+
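+# Use device_map='auto' to spread the model across GPUs only when LoRA is enabled and more than one GPU is available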
+device_map_kwargs = {}
+if args.use_lora != 0 and torch.cuda.device_count() > 1:
+ device_map_kwargs['device_map'] = 'auto'
model = Model.from_pretrained(
- model_dir, cfg_dict=model_config, device_map='auto')
+ model_dir, cfg_dict=model_config, **device_map_kwargs)
if args.ptuning_checkpoint is not None:
# Evaluation
@@ -230,7 +247,10 @@ if args.use_lora != 0:
rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout)
- model = model.bfloat16()
+ if args.use_amp:
+ model = model.float()
+ else:
+ model = model.bfloat16()
Swift.prepare_model(model, lora_config)
prefix = args.source_prefix if args.source_prefix is not None else ''
@@ -333,13 +353,10 @@ def preprocess_function_train(examples):
pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
- if config['model']['type'] == 'chatglm6b':
- labels = labels + [tokenizer.pad_token_id] * pad_len
- if args.ignore_pad_token_for_loss:
- labels = [(lb if lb != tokenizer.pad_token_id else -100)
- for lb in labels]
- else:
- labels = labels + [-100] * pad_len
+ labels = labels + [tokenizer.pad_token_id] * pad_len
+ if args.ignore_pad_token_for_loss:
+ labels = [(lb if lb != tokenizer.pad_token_id else -100)
+ for lb in labels]
model_inputs['input_ids'].append(input_ids)
model_inputs['labels'].append(labels)
@@ -371,8 +388,7 @@ data_collator = DataCollatorForSeq2Seq(
padding=False)
model.gradient_checkpointing_enable()
-if config['model']['type'] == 'chatglm6b':
- model.enable_input_require_grads()
+model.enable_input_require_grads()
# import torch
# model = torch.nn.DataParallel(model).cuda()
@@ -384,8 +400,6 @@ trainer = Seq2SeqTrainer(
seed=args.seed,
data_collator=data_collator,
remove_unused_data=True,
- # No placement for model, leave the model to `device_map`
- device='cpu',
cfg_modify_fn=cfg_modify_fn)
trainer.tokenizer = tokenizer
trainer.train()
diff --git a/examples/pytorch/chatglm6b/run_train_chatglm2_ptuning_adv_v2.sh b/examples/pytorch/chatglm6b/run_train_chatglm2_ptuning_adv_v2.sh
new file mode 100644
index 00000000..582c464c
--- /dev/null
+++ b/examples/pytorch/chatglm6b/run_train_chatglm2_ptuning_adv_v2.sh
@@ -0,0 +1,26 @@
+PRE_SEQ_LEN=128
+LR=2e-2
+
+PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \
+ --train_dataset_name AdvertiseGen/train.json \
+ --val_dataset_name AdvertiseGen/dev.json \
+ --prompt_column content \
+ --response_column summary \
+ --model "ZhipuAI/chatglm2-6b" \
+ --max_source_length 64 \
+ --max_target_length 128 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 1 \
+ --train.optimizer.options.cumulative_iters 1 \
+ --max_epochs 1 \
+ --save_strategy 'by_step' \
+ --save_interval 1000 \
+ --lr $LR \
+ --eval_strategy "by_step" \
+ --eval_interval 1000 \
+ --lr_strategy 'by_step' \
+ --task 'chat' \
+ --model.type 'chatglm2-6b' \
+ --pre_seq_len $PRE_SEQ_LEN \
+ --quantization_bit 4 \
+    --work_dir ptuning_adv_target
diff --git a/examples/pytorch/chatglm6b/text_generation_metric.py b/examples/pytorch/chatglm6b/text_generation_metric.py
index 2083453a..536bbe06 100644
--- a/examples/pytorch/chatglm6b/text_generation_metric.py
+++ b/examples/pytorch/chatglm6b/text_generation_metric.py
@@ -53,7 +53,7 @@ class TextGenerationMetric(Metric):
}
for pred, label in zip(preds, labels):
hypothesis = list(jieba.cut(pred))
- if len(hypothesis) == 0:
+ if len(hypothesis) == 0 or ''.join(hypothesis) == '.':
hypothesis = ['']
reference = list(jieba.cut(label))
rouge = Rouge()
diff --git a/examples/pytorch/llm/_parser.py b/examples/pytorch/llm/_parser.py
new file mode 100644
index 00000000..480cfdce
--- /dev/null
+++ b/examples/pytorch/llm/_parser.py
@@ -0,0 +1,69 @@
+import os
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple, Type, TypeVar, Union
+
+import torch
+from torch import device as Device
+from transformers import HfArgumentParser
+
+from modelscope import get_logger
+
+logger = get_logger()
+
+
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+ if isinstance(device, list):
+ device_ids = device
+ device_str = ','.join([str(d) for d in device])
+ else:
+ device_ids = [int(d) for d in device.split(',') if d != '-1']
+ device_str = device
+ device_str = device_str.replace(' ', '')
+ return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
+ """Call this function before cuda is initialized.
+ device: e.g. []: 'cpu', [0], [0, 1, 2]
+ e.g. '-1': 'cpu', '0', '0,1,2'
+ """
+ if torch.cuda.is_initialized():
+ logger.warning('CUDA has been initialized! Device selection fails!')
+ return torch.device('cuda:0')
+
+ device_ids, device_str = _format_device(device)
+ os.environ['CUDA_VISIBLE_DEVICES'] = device_str
+ log_s = 'Using device: '
+ if len(device_ids) == 0:
+ master_device: str = 'cpu'
+ log_s += 'cpu'
+ else:
+ assert torch.cuda.is_available(
+ ) and torch.cuda.device_count() >= len(device_ids)
+ master_device = 'cuda:0'
+ log_s += f'cuda:{device_str}'
+ logger.info(log_s)
+ return torch.device(master_device)
+
+
+_T = TypeVar('_T')
+
+
+def parse_args(class_type: Type[_T],
+ argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
+ parser = HfArgumentParser([class_type])
+ args, remaining_args = parser.parse_args_into_dataclasses(
+ argv, return_remaining_strings=True)
+ logger.info(f'args: {args}')
+ return args, remaining_args
+
+
+@dataclass
+class DeviceArguments:
+ device: str = '0' # e.g. '-1'; '0'; '0,1'
+
+
+def parse_device(argv: Optional[List[str]] = None) -> List[str]:
+ args, remaining_args = parse_args(DeviceArguments, argv)
+ select_device(args.device)
+ return remaining_args
diff --git a/examples/pytorch/llm/llm_infer.py b/examples/pytorch/llm/llm_infer.py
new file mode 100644
index 00000000..614e3d36
--- /dev/null
+++ b/examples/pytorch/llm/llm_infer.py
@@ -0,0 +1,123 @@
+# ### Setting up experimental environment.
+
+if __name__ == '__main__':
+ # Avoid cuda initialization caused by library import (e.g. peft, accelerate)
+ from _parser import *
+ # argv = parse_device(['--device', '1'])
+ argv = parse_device()
+
+from utils import *
+
+
+@dataclass
+class InferArguments:
+ model_type: str = field(
+ default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
+ sft_type: str = field(
+ default='lora', metadata={'choices': ['lora', 'full']})
+ ckpt_path: str = '/path/to/your/iter_xxx.pth'
+ eval_human: bool = False # False: eval test_dataset
+ ignore_args_error: bool = True # False: notebook compatibility
+
+ dataset: str = field(
+ default='alpaca-en,alpaca-zh',
+ metadata={'help': f'dataset choices: {list(DATASET_MAPPER.keys())}'})
+ dataset_seed: int = 42
+ dataset_sample: Optional[int] = None
+ dataset_test_size: float = 0.01
+ prompt: str = DEFAULT_PROMPT
+ max_length: Optional[int] = 2048
+
+ lora_target_modules: Optional[List[str]] = None
+ lora_rank: int = 8
+ lora_alpha: int = 32
+ lora_dropout_p: float = 0.1
+
+ max_new_tokens: int = 512
+ temperature: float = 0.9
+ top_k: int = 50
+ top_p: float = 0.9
+
+ def __post_init__(self):
+ if self.lora_target_modules is None:
+ self.lora_target_modules = MODEL_MAPPER[self.model_type]['lora_TM']
+
+ if not os.path.isfile(self.ckpt_path):
+ raise ValueError(
+ f'Please enter a valid ckpt_path: {self.ckpt_path}')
+
+
+def llm_infer(args: InferArguments) -> None:
+ # ### Loading Model and Tokenizer
+ support_bf16 = torch.cuda.is_bf16_supported()
+ if not support_bf16:
+ logger.warning(f'support_bf16: {support_bf16}')
+ model, tokenizer, _ = get_model_tokenizer(
+ args.model_type, torch_dtype=torch.bfloat16)
+
+ # ### Preparing lora
+ if args.sft_type == 'lora':
+ lora_config = LoRAConfig(
+ replace_modules=args.lora_target_modules,
+ rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout_p,
+ pretrained_weights=args.ckpt_path)
+ logger.info(f'lora_config: {lora_config}')
+ model = Swift.prepare_model(model, lora_config)
+ elif args.sft_type == 'full':
+ state_dict = torch.load(args.ckpt_path, map_location='cpu')
+ model.load_state_dict(state_dict)
+ else:
+ raise ValueError(f'args.sft_type: {args.sft_type}')
+
+ # ### Inference
+ tokenize_func = partial(
+ tokenize_function,
+ tokenizer=tokenizer,
+ prompt=args.prompt,
+ max_length=args.max_length)
+ streamer = TextStreamer(
+ tokenizer, skip_prompt=True, skip_special_tokens=True)
+ generation_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ do_sample=True,
+ pad_token_id=tokenizer.eos_token_id)
+ logger.info(f'generation_config: {generation_config}')
+
+ if args.eval_human:
+ while True:
+ instruction = input('<<< ')
+ data = {'instruction': instruction}
+ input_ids = tokenize_func(data)['input_ids']
+ inference(input_ids, model, tokenizer, streamer, generation_config)
+ print('-' * 80)
+ else:
+ dataset = get_dataset(args.dataset)
+ _, test_dataset = process_dataset(dataset, args.dataset_test_size,
+ args.dataset_sample,
+ args.dataset_seed)
+ mini_test_dataset = test_dataset.select(range(10))
+ del dataset
+ for data in mini_test_dataset:
+ output = data['output']
+ data['output'] = None
+ input_ids = tokenize_func(data)['input_ids']
+ inference(input_ids, model, tokenizer, streamer, generation_config)
+ print()
+ print(f'[LABELS]{output}')
+ print('-' * 80)
+ # input('next[ENTER]')
+
+
+if __name__ == '__main__':
+ args, remaining_argv = parse_args(InferArguments, argv)
+ if len(remaining_argv) > 0:
+ if args.ignore_args_error:
+ logger.warning(f'remaining_argv: {remaining_argv}')
+ else:
+ raise ValueError(f'remaining_argv: {remaining_argv}')
+ llm_infer(args)
diff --git a/examples/pytorch/llm/llm_sft.py b/examples/pytorch/llm/llm_sft.py
new file mode 100644
index 00000000..a7dabf77
--- /dev/null
+++ b/examples/pytorch/llm/llm_sft.py
@@ -0,0 +1,266 @@
+# ### Setting up experimental environment.
+"""
+# Install the latest version of modelscope from source
+git clone https://github.com/modelscope/modelscope.git
+cd modelscope
+pip install .
+
+conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install numpy pandas -U # Resolve torchmetrics dependencies and update numpy
+pip install matplotlib scikit-learn -U
+pip install transformers datasets -U
+pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer -U
+pip install accelerate transformers_stream_generator -U
+"""
+
+if __name__ == '__main__':
+ # Avoid cuda initialization caused by library import (e.g. peft, accelerate)
+ from _parser import *
+ # argv = parse_device(['--device', '1'])
+ argv = parse_device()
+
+from utils import *
+
+
+@dataclass
+class SftArguments:
+ seed: int = 42
+ model_type: str = field(
+ default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
+ # baichuan-7b: 'lora': 16G; 'full': 80G
+ sft_type: str = field(
+ default='lora', metadata={'choices': ['lora', 'full']})
+ ignore_args_error: bool = True # False: notebook compatibility
+
+ dataset: str = field(
+ default='alpaca-en,alpaca-zh',
+ metadata={'help': f'dataset choices: {list(DATASET_MAPPER.keys())}'})
+ dataset_seed: int = 42
+ dataset_sample: Optional[int] = None
+ dataset_test_size: float = 0.01
+ prompt: str = DEFAULT_PROMPT
+ max_length: Optional[int] = 2048
+
+ lora_target_modules: Optional[List[str]] = None
+ lora_rank: int = 8
+ lora_alpha: int = 32
+ lora_dropout_p: float = 0.1
+
+ gradient_checkpoint: bool = True
+ batch_size: int = 1
+ max_epochs: int = 1
+ learning_rate: Optional[float] = None
+ weight_decay: float = 0.01
+ n_accumulate_grad: int = 16
+ grad_clip_norm: float = 1.
+ warmup_iters: int = 200
+
+ save_trainer_state: Optional[bool] = None
+ eval_interval: int = 500
+ last_save_interval: Optional[int] = None
+ last_max_checkpoint_num: int = 1
+ best_max_checkpoint_num: int = 1
+ logging_interval: int = 5
+ tb_interval: int = 5
+
+ def __post_init__(self):
+ if self.sft_type == 'lora':
+ if self.learning_rate is None:
+ self.learning_rate = 1e-4
+ if self.save_trainer_state is None:
+ self.save_trainer_state = True
+ if self.last_save_interval is None:
+ self.last_save_interval = self.eval_interval
+ elif self.sft_type == 'full':
+ if self.learning_rate is None:
+ self.learning_rate = 1e-5
+ if self.save_trainer_state is None:
+ self.save_trainer_state = False # save disk space
+ if self.last_save_interval is None:
+ # Saving the model takes a long time
+ self.last_save_interval = self.eval_interval * 4
+ else:
+ raise ValueError(f'sft_type: {self.sft_type}')
+
+ if self.lora_target_modules is None:
+ self.lora_target_modules = MODEL_MAPPER[self.model_type]['lora_TM']
+
+
+def llm_sft(args: SftArguments) -> None:
+ seed_everything(args.seed)
+
+ # ### Loading Model and Tokenizer
+ support_bf16 = torch.cuda.is_bf16_supported()
+ if not support_bf16:
+ logger.warning(f'support_bf16: {support_bf16}')
+ model, tokenizer, model_dir = get_model_tokenizer(
+ args.model_type, torch_dtype=torch.bfloat16)
+
+ if args.gradient_checkpoint:
+ # baichuan-13b does not implement the `get_input_embeddings` function
+ if args.model_type == 'baichuan-13b':
+ model.get_input_embeddings = MethodType(
+ lambda self: self.model.embed_tokens, model)
+ model.gradient_checkpointing_enable()
+ model.enable_input_require_grads()
+
+ # ### Preparing lora
+ if args.sft_type == 'lora':
+ lora_config = LoRAConfig(
+ replace_modules=args.lora_target_modules,
+ rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout_p)
+ logger.info(f'lora_config: {lora_config}')
+ model = Swift.prepare_model(model, lora_config)
+
+ show_freeze_layers(model)
+ print_model_info(model)
+ # check the device and dtype of the model
+ _p: Tensor = list(model.parameters())[-1]
+ logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
+
+ # ### Loading Dataset
+ dataset = get_dataset(args.dataset)
+ train_dataset, val_dataset = process_dataset(dataset,
+ args.dataset_test_size,
+ args.dataset_sample,
+ args.dataset_seed)
+ tokenize_func = partial(
+ tokenize_function,
+ tokenizer=tokenizer,
+ prompt=args.prompt,
+ max_length=args.max_length)
+ train_dataset = train_dataset.map(tokenize_func)
+ val_dataset = val_dataset.map(tokenize_func)
+ del dataset
+ # Data analysis
+ stat_dataset(train_dataset)
+ stat_dataset(val_dataset)
+ data_collator = partial(data_collate_fn, tokenizer=tokenizer)
+ print_example(train_dataset[0], tokenizer)
+
+ # ### Setting Config
+ cfg_file = os.path.join(model_dir, 'configuration.json')
+
+ T_max = get_T_max(
+ len(train_dataset), args.batch_size, args.max_epochs, True)
+ work_dir = get_work_dir(f'runs/{args.model_type}')
+ config = Config({
+ 'train': {
+ 'dataloader': {
+ 'batch_size_per_gpu': args.batch_size,
+ 'workers_per_gpu': 1,
+ 'shuffle': True,
+ 'drop_last': True,
+ 'pin_memory': True
+ },
+ 'max_epochs':
+ args.max_epochs,
+ 'work_dir':
+ work_dir,
+ 'optimizer': {
+ 'type': 'AdamW',
+ 'lr': args.learning_rate,
+ 'weight_decay': args.weight_decay,
+ 'options': {
+ 'cumulative_iters': args.n_accumulate_grad,
+ 'grad_clip': {
+ 'norm_type': 2,
+ 'max_norm': args.grad_clip_norm
+ }
+ }
+ },
+ 'lr_scheduler': {
+ 'type': 'CosineAnnealingLR',
+ 'T_max': T_max,
+ 'eta_min': args.learning_rate * 0.1,
+ 'options': {
+ 'by_epoch': False,
+ 'warmup': {
+ 'type': 'LinearWarmup',
+ 'warmup_ratio': 0.1,
+ 'warmup_iters': args.warmup_iters
+ }
+ }
+ },
+ 'hooks': [
+ {
+ 'type': 'CheckpointHook',
+ 'by_epoch': False,
+ 'interval': args.last_save_interval,
+ 'max_checkpoint_num': args.last_max_checkpoint_num,
+ 'save_trainer_state': args.save_trainer_state
+ },
+ {
+ 'type': 'EvaluationHook',
+ 'by_epoch': False,
+ 'interval': args.eval_interval
+ },
+ {
+ 'type': 'BestCkptSaverHook',
+ 'metric_key': 'loss',
+ 'save_best': True,
+ 'rule': 'min',
+ 'max_checkpoint_num': args.best_max_checkpoint_num,
+ 'save_trainer_state': args.save_trainer_state
+ },
+ {
+ 'type': 'TextLoggerHook',
+ 'by_epoch': True, # Whether EpochBasedTrainer is used
+ 'interval': args.logging_interval
+ },
+ {
+ 'type': 'TensorboardHook',
+ 'by_epoch': False,
+ 'interval': args.tb_interval
+ }
+ ]
+ },
+ 'evaluation': {
+ 'dataloader': {
+ 'batch_size_per_gpu': args.batch_size,
+ 'workers_per_gpu': 1,
+ 'shuffle': False,
+ 'drop_last': False,
+ 'pin_memory': True
+ },
+ 'metrics': [{
+ 'type': 'my_metric',
+ 'vocab_size': tokenizer.vocab_size
+ }]
+ }
+ })
+
+ # ### Finetuning
+
+ def cfg_modify_fn(cfg: Config) -> Config:
+ cfg.update(config)
+ return cfg
+
+ trainer = EpochBasedTrainer(
+ model=model,
+ cfg_file=cfg_file,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=val_dataset,
+ remove_unused_data=True,
+ seed=42,
+ cfg_modify_fn=cfg_modify_fn,
+ )
+
+ trainer.train()
+
+ # ### Visualization
+ tb_dir = os.path.join(work_dir, 'tensorboard_output')
+ plot_images(tb_dir, ['loss'], 0.9)
+
+
+if __name__ == '__main__':
+ args, remaining_argv = parse_args(SftArguments, argv)
+ if len(remaining_argv) > 0:
+ if args.ignore_args_error:
+ logger.warning(f'remaining_argv: {remaining_argv}')
+ else:
+ raise ValueError(f'remaining_argv: {remaining_argv}')
+ llm_sft(args)
diff --git a/examples/pytorch/llm/run_infer.sh b/examples/pytorch/llm/run_infer.sh
new file mode 100644
index 00000000..aa1a1a04
--- /dev/null
+++ b/examples/pytorch/llm/run_infer.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+python llm_infer.py \
+ --device 0,1 \
+ --model_type openbuddy-llama2-13b \
+ --ckpt_path "runs/openbuddy-llama2-13b/vx_xxx/output_best/pytorch_model.bin" \
+ --eval_human true
diff --git a/examples/pytorch/llm/run_sft.sh b/examples/pytorch/llm/run_sft.sh
new file mode 100644
index 00000000..3a6d9ff4
--- /dev/null
+++ b/examples/pytorch/llm/run_sft.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+DATE=$(date +"%Y%m%d-%H%M%S")
+nohup python llm_sft.py \
+ --device 0,1 \
+ --model_type openbuddy-llama2-13b \
+ --dataset alpaca-en,alpaca-zh \
+ --dataset_sample 20000 \
+&> train_$DATE.out &
diff --git a/examples/pytorch/llm/utils/__init__.py b/examples/pytorch/llm/utils/__init__.py
new file mode 100644
index 00000000..e4772c03
--- /dev/null
+++ b/examples/pytorch/llm/utils/__init__.py
@@ -0,0 +1,5 @@
+from _parser import *
+
+from .dataset import *
+from .models import *
+from .utils import *
diff --git a/examples/pytorch/llm/utils/dataset.py b/examples/pytorch/llm/utils/dataset.py
new file mode 100644
index 00000000..3035ba78
--- /dev/null
+++ b/examples/pytorch/llm/utils/dataset.py
@@ -0,0 +1,72 @@
+from typing import Optional, Tuple
+
+import numpy as np
+from datasets import Dataset as HfDataset
+from datasets import concatenate_datasets
+from numpy.random import RandomState
+
+from modelscope import MsDataset
+
+
+def _processing_alpaca(dataset: HfDataset) -> HfDataset:
+ instruction = dataset['instruction']
+ input_ = dataset['input']
+ res = []
+ for inst, inp in zip(instruction, input_):
+ if inp is not None and inp != '':
+ if inp.startswith('输入:'):
+ inp = inp[3:]
+ inst = f'{inst}\n{inp}'
+ res.append(inst)
+ dataset = HfDataset.from_dict({
+ 'instruction': res,
+ 'output': dataset['output']
+ })
+ return dataset
+
+
+def get_alpaca_en_dataset() -> HfDataset:
+ dataset_en: HfDataset = MsDataset.load(
+ 'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset()
+ dataset_en = dataset_en.remove_columns(['text'])
+ return _processing_alpaca(dataset_en)
+
+
+def get_alpaca_zh_dataset() -> HfDataset:
+ dataset_zh: HfDataset = MsDataset.load(
+ 'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset()
+ return _processing_alpaca(dataset_zh)
+
+
+def get_seed(random_state: RandomState) -> int:
+ seed_max = np.iinfo(np.int32).max
+ seed = random_state.randint(0, seed_max)
+ return seed
+
+
+def process_dataset(dataset: HfDataset, dataset_test_size: float,
+ dataset_sample: Optional[int],
+ dataset_seed: int) -> Tuple[HfDataset, HfDataset]:
+ random_state = np.random.RandomState(dataset_seed)
+ if dataset_sample is not None:
+ index = random_state.permutation(len(dataset))[:dataset_sample]
+ dataset = dataset.select(index)
+ dataset = dataset.train_test_split(
+ dataset_test_size, seed=get_seed(random_state))
+ return dataset['train'], dataset['test']
+
+
+DATASET_MAPPER = {
+ 'alpaca-en': get_alpaca_en_dataset,
+ 'alpaca-zh': get_alpaca_zh_dataset,
+}
+
+
+def get_dataset(dataset_names: str) -> HfDataset:
+ dataset_name_list = dataset_names.split(',')
+ dataset_list = []
+ for dataset_name in dataset_name_list:
+ get_function = DATASET_MAPPER[dataset_name]
+ dataset_list.append(get_function())
+ dataset = concatenate_datasets(dataset_list)
+ return dataset
diff --git a/examples/pytorch/llm/utils/models.py b/examples/pytorch/llm/utils/models.py
new file mode 100644
index 00000000..c95df561
--- /dev/null
+++ b/examples/pytorch/llm/utils/models.py
@@ -0,0 +1,133 @@
+from typing import NamedTuple
+
+import torch
+from torch import dtype as Dtype
+
+from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model,
+ get_logger, read_config, snapshot_download)
+from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer
+
+logger = get_logger()
+
+
+def _add_special_token(tokenizer):
+ if tokenizer.eos_token_id is None:
+ tokenizer.eos_token_id = 2
+ if tokenizer.bos_token_id is None:
+ tokenizer.bos_token_id = 1
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = 0
+ logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+ f'eos_token_id: {tokenizer.eos_token_id}, '
+ f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_model_tokenizer_default(model_dir: str,
+ load_model: bool = True,
+ add_special_token: bool = True,
+ torch_dtype: Dtype = torch.float16):
+ """load from an independent repository"""
+ model_config = AutoConfig.from_pretrained(
+ model_dir, trust_remote_code=True)
+ model_config.torch_dtype = torch_dtype
+ logger.info(f'model_config: {model_config}')
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_dir, trust_remote_code=True)
+ model = None
+ if load_model:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_dir,
+ config=model_config,
+ device_map='auto',
+ torch_dtype=torch_dtype,
+ trust_remote_code=True)
+
+ if add_special_token:
+ _add_special_token(tokenizer)
+ return model, tokenizer
+
+
+def get_model_tokenizer_chatglm2(model_dir: str,
+ load_model: bool = True,
+ add_special_token: bool = True,
+ torch_dtype: Dtype = torch.float16):
+ """load from ms library"""
+ config = read_config(model_dir)
+ logger.info(config)
+ model_config = ChatGLM2Config.from_pretrained(model_dir)
+ model_config.torch_dtype = torch_dtype
+ logger.info(model_config)
+ tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
+ model = None
+ if load_model:
+ model = Model.from_pretrained(
+ model_dir,
+ cfg_dict=config,
+ config=model_config,
+ device_map='auto',
+ torch_dtype=torch_dtype)
+ if add_special_token:
+ _add_special_token(tokenizer)
+ return model, tokenizer
+
+
+class LoRATM(NamedTuple):
+ # default lora target modules
+ baichuan = ['W_pack']
+ chatglm2 = ['query_key_value']
+ llama2 = ['q_proj', 'k_proj', 'v_proj']
+
+
+# Reference: 'https://modelscope.cn/models/{model_id}/summary'
+MODEL_MAPPER = {
+ 'baichuan-7b': {
+ 'model_id': 'baichuan-inc/baichuan-7B',
+ 'revision': 'v1.0.7',
+ 'lora_TM': LoRATM.baichuan
+ },
+ 'baichuan-13b': {
+ 'model_id': 'baichuan-inc/Baichuan-13B-Base',
+ 'revision': 'v1.0.3',
+ 'lora_TM': LoRATM.baichuan
+ },
+ 'chatglm2': {
+ 'model_id': 'ZhipuAI/chatglm2-6b',
+ 'revision': 'v1.0.6',
+ 'get_function': get_model_tokenizer_chatglm2,
+ 'lora_TM': LoRATM.chatglm2
+ },
+ 'llama2-7b': {
+ 'model_id': 'modelscope/Llama-2-7b-ms',
+ 'revision': 'v1.0.2',
+ 'ignore_file_pattern': [r'.+\.bin$'], # use safetensors
+ 'lora_TM': LoRATM.llama2
+ },
+ 'llama2-13b': {
+ 'model_id': 'modelscope/Llama-2-13b-ms',
+ 'revision': 'v1.0.2',
+ 'ignore_file_pattern': [r'.+\.bin$'],
+ 'lora_TM': LoRATM.llama2
+ },
+ 'openbuddy-llama2-13b': {
+ 'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16',
+ 'lora_TM': LoRATM.llama2
+ }
+}
+
+
+def get_model_tokenizer(model_type: str,
+ load_model: bool = True,
+ add_special_token: bool = True,
+ torch_dtype: Dtype = torch.float16):
+ data = MODEL_MAPPER.get(model_type)
+ if data is None:
+ raise ValueError(f'model_type: {model_type}')
+ model_id = data['model_id']
+ revision = data.get('revision', 'master')
+ get_function = data.get('get_function', get_model_tokenizer_default)
+ ignore_file_pattern = data.get('ignore_file_pattern', [])
+ model_dir = snapshot_download(
+ model_id, revision, ignore_file_pattern=ignore_file_pattern)
+ model, tokenizer = get_function(model_dir, load_model, add_special_token,
+ torch_dtype)
+ return model, tokenizer, model_dir
diff --git a/examples/pytorch/llm/utils/utils.py b/examples/pytorch/llm/utils/utils.py
new file mode 100644
index 00000000..5b8ee163
--- /dev/null
+++ b/examples/pytorch/llm/utils/utils.py
@@ -0,0 +1,321 @@
+import datetime as dt
+import math
+import os
+import random
+import re
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from types import MethodType
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from datasets import Dataset as HfDataset
+from numpy import ndarray
+from tensorboard.backend.event_processing.event_accumulator import \
+ EventAccumulator
+from torch import Tensor
+from torch import device as Device
+from torch import dtype as Dtype
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+from torchmetrics import Accuracy, MeanMetric
+from tqdm import tqdm
+from transformers import GenerationConfig, TextStreamer
+
+from modelscope import get_logger
+from modelscope.metrics.base import Metric
+from modelscope.metrics.builder import METRICS
+from modelscope.swift import LoRAConfig, Swift
+from modelscope.trainers import EpochBasedTrainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.registry import default_group
+
+COLOR, COLOR_S = '#FFE2D9', '#FF7043'
+
+DEFAULT_PROMPT = """Here's a conversation between a human and an AI assistant. \
+The AI assistant provides detailed, friendly answers for the human.
+
+### Human:
+{instruction}
+
+### AI:
+"""
+
+logger = get_logger()
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+
+
+def _get_version(work_dir: str) -> int:
+ if os.path.isdir(work_dir):
+ fnames = os.listdir(work_dir)
+ else:
+ fnames = []
+ v_list = [-1]
+ for fname in fnames:
+ m = re.match(r'v(\d+)', fname)
+ if m is None:
+ continue
+ v = m.group(1)
+ v_list.append(int(v))
+ return max(v_list) + 1
+
+
+def get_work_dir(work_dir: str) -> str:
+ """add version"""
+ work_dir = os.path.abspath(work_dir)
+ version = _get_version(work_dir)
+ time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
+
+ work_dir = os.path.join(work_dir, f'v{version}-{time}')
+ logger.info(f'work_dir: {work_dir}')
+ return work_dir
+
+
+def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
+ if seed is None:
+ seed_max = np.iinfo(np.int32).max
+ seed = random.randint(0, seed_max)
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ logger.info(f'Global seed set to {seed}')
+ if gpu_dtm:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ logger.info(f'Setting deterministic: {True}, benchmark: {False}')
+ return seed
+
+
+def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
+ drop_last: bool) -> int:
+ """Calculate T_max in CosineAnnealingLR"""
+ if drop_last:
+ T_max = dataset_len // batch_size
+ else:
+ T_max = math.ceil(dataset_len / batch_size)
+ T_max *= max_epochs
+ return T_max
+
+
+def tokenize_function(example: Dict[str, Optional[str]],
+ tokenizer,
+ prompt: str = DEFAULT_PROMPT,
+ max_length: Optional[int] = 2048) -> Dict[str, Any]:
+ instruction: str = example['instruction']
+ output = example.get('output')
+ src_text = prompt.format(instruction=instruction)
+ src_input_ids: List[int] = tokenizer(
+ src_text, return_attention_mask=False,
+ add_special_tokens=True)['input_ids']
+
+ tgt_input_ids = []
+ if output is not None:
+ tgt_input_ids += tokenizer(
+ output, return_attention_mask=False,
+ add_special_tokens=False)['input_ids']
+ tgt_input_ids += [tokenizer.eos_token_id]
+ labels = [-100] * len(src_input_ids) + tgt_input_ids
+ else:
+ labels = None
+ input_ids = src_input_ids + tgt_input_ids
+
+ if max_length is not None:
+ input_ids = input_ids[-max_length:]
+ if labels is not None:
+ labels = labels[-max_length:]
+
+ return {'input_ids': input_ids, 'labels': labels}
+
+
+def stat_dataset(dataset: HfDataset) -> None:
+    """Compute and log token-length statistics for the dataset"""
+ _token_len = []
+ for d in dataset:
+ _token_len.append(len(d['input_ids']))
+ _token_len = np.array(_token_len)
+ mean = _token_len.mean().item()
+ std = _token_len.std().item()
+ min_ = _token_len.min().item()
+ max_ = _token_len.max().item()
+ logger.info(
+ f'Dataset Token Length: {mean:.6f}±{std:.6f}, min={min_:.6f}, max={max_:.6f}, size={_token_len.shape[0]}'
+ )
+
+
+def print_example(example: Dict[str, Any], tokenizer) -> None:
+ input_ids, labels = example['input_ids'], example['labels']
+ print(f'[INPUT_IDS] {input_ids}')
+ print(f'[INPUT] {tokenizer.decode(input_ids)}')
+ print()
+    print(f'[LABELS_IDS] {labels}')
+ print(
+        f'[LABELS] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
+ )
+
+
+def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, Any]:
+ input_ids = [torch.tensor(b['input_ids']) for b in batch]
+ labels = [torch.tensor(b['labels']) for b in batch]
+ attention_mask = [
+ torch.ones(len(input_ids[i]), dtype=torch.int64)
+ for i in range(len(input_ids))
+ ]
+
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ attention_mask = pad_sequence(
+ attention_mask, batch_first=True, padding_value=0)
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
+ return {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ 'labels': labels
+ }
+
+
+def print_model_info(model: Module, name: Optional[str] = None) -> None:
+ if name is None:
+ name = model.__class__.__name__
+
+ n_params = sum(p.numel() for p in model.parameters())
+ n_grads = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ n_buffers = sum(p.numel() for p in model.buffers())
+
+ n_params /= 1e6
+ n_grads /= 1e6
+ n_buffers /= 1e6
+ s = [
+ f'{name}: ',
+ f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable), ',
+ f'{n_buffers:.4f}M Buffers',
+ ]
+ s += '.'
+ logger.info(''.join(s))
+
+
+def show_freeze_layers(model: Module, max_lines: int = 20) -> None:
+ named_p = list(model.named_parameters())
+ for i, (n, p) in enumerate(named_p):
+ if i >= max_lines:
+ logger.info('...')
+ break
+ logger.info(f'{n}: requires_grad={p.requires_grad}')
+
+
+@METRICS.register_module(group_key=default_group, module_name='my_metric')
+class MyMetric(Metric):
+
+ def __init__(self, vocab_size: int):
+ self.acc = Accuracy('multiclass', num_classes=vocab_size)
+ self.loss = MeanMetric()
+
+ def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> None:
+ loss: Tensor = outputs.loss
+ self.loss.update(loss.cpu())
+
+ labels: Tensor = inputs['labels']
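+        # Shift labels by one position so the logit at step t is scored against token t+1 (next-token prediction)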
+ labels = labels[:, 1:]
+ labels_mask = labels != -100
+ logits: Tensor = outputs.logits[:, :-1]
+ logits = logits[labels_mask].contiguous().view(-1, logits.shape[-1])
+ pred = logits.argmax(dim=-1)
+ labels = labels[labels_mask].to(logits.device)
+ self.acc.update(pred.cpu(), labels.cpu())
+
+ def evaluate(self):
+ return {
+ 'acc': self.acc.compute().item(),
+ 'loss': self.loss.compute().item()
+ }
+
+ def merge(self, other: 'MyMetric') -> None:
+ """This script does not support ddp. TODO"""
+ raise NotImplementedError
+
+
+Item = Dict[str, float]
+
+
+def read_tensorboard_file(fpath: str) -> Dict[str, List[Item]]:
+ if not os.path.isfile(fpath):
+ raise FileNotFoundError(f'fpath: {fpath}')
+ ea = EventAccumulator(fpath)
+ ea.Reload()
+ res = {}
+ tags = ea.Tags()['scalars']
+ for tag in tags:
+ values = ea.Scalars(tag)
+ r = []
+ for v in values:
+ r.append({'step': v.step, 'value': v.value})
+ res[tag] = r
+ return res
+
+
+def tensorboard_smoothing(values: List[float],
+ smooth: float = 0.9) -> List[float]:
+ norm_factor = 1
+ x = 0
+ res = []
+ for i in range(len(values)):
+ x = x * smooth + values[i] # Exponential decay
+ res.append(x / norm_factor)
+
+ norm_factor *= smooth
+ norm_factor += 1
+ return res
+
+
+def plot_images(tb_dir: str,
+ smooth_key: List[str],
+ smooth_val: float = 0.9,
+ figsize: Tuple[int, int] = (8, 5),
+ dpi: int = 100) -> None:
+ images_dir = os.path.join(os.path.dirname(tb_dir), 'images')
+ os.makedirs(images_dir, exist_ok=True)
+
+ fname = os.listdir(tb_dir)[0]
+ tb_path = os.path.join(tb_dir, fname)
+ data = read_tensorboard_file(tb_path)
+
+ for k in data.keys():
+ _data = data[k]
+ steps = [d['step'] for d in _data]
+ values = [d['value'] for d in _data]
+ if len(values) == 0:
+ continue
+ _, ax = plt.subplots(1, 1, squeeze=True, figsize=figsize, dpi=dpi)
+ ax.set_title(k)
+ if len(values) == 1:
+ ax.scatter(steps, values, color=COLOR_S)
+ elif k in smooth_key:
+ ax.plot(steps, values, color=COLOR)
+ values_s = tensorboard_smoothing(values, smooth_val)
+ ax.plot(steps, values_s, color=COLOR_S)
+ else:
+ ax.plot(steps, values, color=COLOR_S)
+ fpath = os.path.join(images_dir, k.replace('/', '_'))
+ plt.savefig(fpath, dpi=dpi, bbox_inches='tight')
+
+
+def inference(input_ids: List[int],
+ model,
+ tokenizer,
+ streamer: Optional[TextStreamer] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ tag: str = '[INFERENCE]') -> str:
+ print(f'{tag}{tokenizer.decode(input_ids)}', end='')
+ input_ids = torch.tensor(input_ids)[None].cuda()
+ attention_mask = torch.ones_like(input_ids)
+ generate_ids = model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ streamer=streamer,
+ generation_config=generation_config)
+ output_text = tokenizer.decode(generate_ids[0])
+ return output_text
diff --git a/examples/pytorch/llm_agent/_common.py b/examples/pytorch/llm_agent/_common.py
new file mode 100644
index 00000000..dd07ef31
--- /dev/null
+++ b/examples/pytorch/llm_agent/_common.py
@@ -0,0 +1,426 @@
+import ast
+import datetime as dt
+import math
+import os
+import random
+import re
+import sys
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+#
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from matplotlib.axes import Axes
+from matplotlib.figure import Figure
+from numpy import ndarray
+from tensorboard.backend.event_processing.event_accumulator import \
+ EventAccumulator
+from torch import Tensor
+from torch import device as Device
+from torch import dtype as Dtype
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from torch.nn.utils.rnn import pad_sequence
+from torch.optim import Optimizer
+from torch.optim import lr_scheduler as lrs
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import Dataset
+#
+from torchmetrics import Accuracy, MeanMetric
+#
+from tqdm import tqdm
+
+#
+from modelscope import (Model, MsDataset, get_logger, read_config,
+ snapshot_download)
+from modelscope.metrics.base import Metric
+from modelscope.metrics.builder import METRICS
+from modelscope.models.nlp.chatglm2 import ChatGLM2Tokenizer
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ TorchCustomDataset
+from modelscope.swift import LoRAConfig, Swift
+from modelscope.trainers import EpochBasedTrainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.registry import default_group
+
+#
+PROMPT = """System: {system}
+Human: {user}
+AI: """
+MAX_LENGTH = 2048
+TEST_MAX_LENGTH = MAX_LENGTH
+
+COLOR, COLOR_S = '#FFE2D9', '#FF7043'
+logger = get_logger()
+#
+
+
+def _get_version(work_dir: str) -> int:
+ if os.path.isdir(work_dir):
+ fnames = os.listdir(work_dir)
+ else:
+ fnames = []
+ v_list = [-1]
+ for fname in fnames:
+ m = re.match(r'v(\d+)', fname)
+ if m is None:
+ continue
+ v = m.group(1)
+ v_list.append(int(v))
+ return max(v_list) + 1
+
+
+def get_work_dir(work_dir: str) -> str:
+ """add version"""
+ work_dir = os.path.abspath(work_dir)
+ version = _get_version(work_dir)
+ time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
+ #
+ work_dir = os.path.join(work_dir, f'v{version}-{time}')
+ logger.info(f'work_dir: {work_dir}')
+ return work_dir
+
+
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+ if isinstance(device, list):
+ device_ids = device
+ device_str = ','.join([str(d) for d in device])
+ else:
+ device_ids = [int(d) for d in device.split(',') if d != '-1']
+ device_str = device
+ device_str = device_str.replace(' ', '')
+ return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
+ """Call this function before cuda is initialized.
+ device: e.g. []: 'cpu', [0], [0, 1, 2]
+ e.g. '-1': 'cpu', '0', '0,1,2'
+ """
+ if torch.cuda.is_initialized():
+ logger.warning('CUDA has been initialized! Device selection fails!')
+ return torch.device('cuda:0')
+ #
+ device_ids, device_str = _format_device(device)
+ #
+ os.environ['CUDA_VISIBLE_DEVICES'] = device_str
+ log_s = 'Using device: '
+ if len(device_ids) == 0:
+ master_device: str = 'cpu'
+ log_s += 'cpu'
+ else:
+ assert torch.cuda.is_available(
+ ) and torch.cuda.device_count() >= len(device_ids)
+ master_device = 'cuda:0'
+ log_s += f'cuda:{device_str}'
+ logger.info(log_s)
+ return torch.device(master_device)
+
+
+def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
+ if seed is None:
+ seed_max = np.iinfo(np.int32).max
+ seed = random.randint(0, seed_max)
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ logger.info(f'Global seed set to {seed}')
+ if gpu_dtm:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ logger.info(f'Setting deterministic: {True}, benchmark: {False}')
+ return seed
+
+
+def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
+ drop_last: bool) -> int:
+ """Calculate T_max in CosineAnnealingLR"""
+ if drop_last:
+ T_max = dataset_len // batch_size
+ else:
+ T_max = math.ceil(dataset_len / batch_size)
+ T_max *= max_epochs
+ return T_max
+
+
+def tokenize_function(system: str, user: str, assistant: Optional[str],
+ tokenizer) -> Dict[str, Any]:
+ """Only applicable to baichuan and chatglm2. Other models need to be tested"""
+ src_text = PROMPT.format(system=system, user=user)
+ src_input_ids: List[int] = tokenizer(
+ src_text, return_attention_mask=False,
+ add_special_tokens=True)['input_ids']
+ #
+ tgt_input_ids: List[int] = []
+ if assistant is not None:
+ tgt_input_ids += tokenizer(
+ assistant, return_attention_mask=False,
+ add_special_tokens=False)['input_ids']
+ tgt_input_ids += [tokenizer.eos_token_id]
+ labels = [-100] * len(src_input_ids) + tgt_input_ids
+ else:
+ labels = None
+ input_ids = src_input_ids + tgt_input_ids
+ #
+ if assistant is not None:
+ if len(input_ids) > MAX_LENGTH:
+ return {}
+ else:
+ input_ids = input_ids[-TEST_MAX_LENGTH:]
+ #
+ return {'input_ids': input_ids, 'labels': labels}
+
+
+class MyDataset(TorchCustomDataset):
+
+ def __init__(self, system: List[str], user: List[str],
+ assistant: List[str], tokenize_function) -> None:
+ self._data = []
+ for i in tqdm(range(len(system))):
+ _d = tokenize_function(system[i], user[i], assistant[i])
+ if len(_d) == 0:
+ continue
+ self._data.append(_d)
+
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
+ return self._data[idx]
+
+ def __len__(self) -> int:
+ return len(self._data)
+
+
+def stat_dataset(dataset: 'MyDataset') -> None:
+    """Compute and log token-length statistics for the dataset"""
+ _token_len = []
+ for d in dataset:
+ _token_len.append(len(d['input_ids']))
+ _token_len = np.array(_token_len)
+ mean = _token_len.mean().item()
+ std = _token_len.std().item()
+ min_ = _token_len.min().item()
+ max_ = _token_len.max().item()
+ logger.info(
+ f'Dataset Token Length: {mean:.6f}±{std:.6f}, min={min_:.6f}, max={max_:.6f}, size={_token_len.shape[0]}'
+ )
+
+
+def print_examples(examples: Dict[str, Any], tokenizer) -> None:
+ input_ids, labels = examples['input_ids'], examples['labels']
+ print(f'[INPUT_IDS] {tokenizer.decode(input_ids)}')
+ print()
+ print(
+        f'[LABELS] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
+ )
+
+
+def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, Any]:
+ input_ids = [torch.tensor(b['input_ids']) for b in batch]
+ labels = [torch.tensor(b['labels']) for b in batch]
+ attention_mask = [
+ torch.ones(len(input_ids[i]), dtype=torch.int64)
+ for i in range(len(input_ids))
+ ]
+ #
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ attention_mask = pad_sequence(
+ attention_mask, batch_first=True, padding_value=0)
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
+ return {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ 'labels': labels
+ }
+
+
+def print_model_info(model: Module, name: Optional[str] = None) -> None:
+ if name is None:
+ name = model.__class__.__name__
+ #
+ n_params = sum(p.numel() for p in model.parameters())
+ n_grads = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ n_buffers = sum(p.numel() for p in model.buffers())
+ #
+ n_params /= 1e6
+ n_grads /= 1e6
+ n_buffers /= 1e6
+ s = [
+ f'{name}: ',
+ f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable), ',
+ f'{n_buffers:.4f}M Buffers',
+ ]
+ s += '.'
+ logger.info(''.join(s))
+
+
+def show_freeze_layers(model: Module, max_lines: int = 20) -> None:
+ named_p = list(model.named_parameters())
+ for i, (n, p) in enumerate(named_p):
+ if i >= max_lines:
+ logger.info('...')
+ break
+ logger.info(f'{n}: requires_grad={p.requires_grad}')
+
+
+@METRICS.register_module(group_key=default_group, module_name='my_metric')
+class MyMetric(Metric):
+
+ def __init__(self, vocab_size: int):
+ self.acc = Accuracy('multiclass', num_classes=vocab_size)
+ self.loss = MeanMetric()
+
+ def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> None:
+ loss: Tensor = outputs.loss
+ self.loss.update(loss)
+ #
+ labels: Tensor = inputs['labels']
+ labels = labels[:, 1:]
+ labels_mask = labels != -100
+ logits: Tensor = outputs.logits[:, :-1]
+ logits = logits[labels_mask].contiguous().view(-1, logits.shape[-1])
+ pred = logits.argmax(dim=-1)
+ labels = labels[labels_mask].to(logits.device)
+ self.acc.update(pred, labels)
+
+ def evaluate(self):
+ return {
+ 'acc': self.acc.compute().item(),
+ 'loss': self.loss.compute().item()
+ }
+
+ def merge(self, other: 'MyMetric') -> None:
+ """This script does not support ddp"""
+ raise NotImplementedError
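+
+# Note on the shift in MyMetric.add (a sketch of the assumed convention): for a
+# causal LM, logits[:, :-1] predict labels[:, 1:], so both are shifted by one
+# before computing accuracy, and positions equal to -100 (prompt/padding) are
+# masked out.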
+
+
+def _add_special_token(tokenizer):
+ if tokenizer.eos_token_id is None:
+ tokenizer.eos_token_id = 2
+ if tokenizer.bos_token_id is None:
+ tokenizer.bos_token_id = 1
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = 0
+ logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+ f'eos_token_id: {tokenizer.eos_token_id}, '
+ f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_baichuan7B_model_tokenizer(model_dir: str,
+ load_model: bool = True,
+ add_special_token: bool = True):
+ sys.path.insert(0, model_dir)
+ from configuration_baichuan import BaiChuanConfig
+ from tokenization_baichuan import BaiChuanTokenizer
+ from modeling_baichuan import BaiChuanForCausalLM
+ model_config = BaiChuanConfig.from_pretrained(model_dir)
+ model_config.torch_dtype = torch.float16
+ logger.info(f'model_config: {model_config}')
+ tokenizer = BaiChuanTokenizer.from_pretrained(model_dir)
+ model = None
+ if load_model:
+ model = BaiChuanForCausalLM.from_pretrained(
+ model_dir,
+ config=model_config,
+ device_map='auto',
+ torch_dtype=torch.float16)
+ #
+ if add_special_token:
+ _add_special_token(tokenizer)
+ return model, tokenizer
+
+
+def get_chatglm2_model_tokenizer(model_dir: str,
+ load_model: bool = True,
+ add_special_token: bool = True):
+ config = read_config(model_dir)
+ config['model'] = ConfigDict({'type': 'chatglm2-6b'})
+ tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
+ model = None
+ if load_model:
+ model = Model.from_pretrained(
+ model_dir,
+ cfg_dict=config,
+ device_map='auto',
+ torch_dtype=torch.float16)
+ if add_special_token:
+ _add_special_token(tokenizer)
+ return model, tokenizer
+
+
+def make_dataset(
+ split: str, tokenize_function: Callable[[str, str, Optional[str]],
+ Dict[str, Any]]
+) -> MyDataset:
+ """
+ split: Literal['train', 'validation']
+ """
+ dataset = MsDataset.load(
+ 'modelscope/ms_hackathon_23_agent_train_dev', split=split)
+ system = []
+ user = []
+ assistant = []
+ for d in dataset:
+ content = ast.literal_eval(d['conversations'])
+ s = content[0]['value']
+ assert len(content) % 2 == 1
+ for i in range(len(content) // 2):
+ system.append(s)
+ user.append(content[2 * i + 1]['value'])
+ assistant.append(content[2 * i + 2]['value'])
+ return MyDataset(system, user, assistant, tokenize_function)
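+
+# Record layout assumed by the parsing above: `conversations` is a python-literal
+# list of dicts, each with at least a 'value' key; element 0 holds the system
+# prompt and the remaining elements alternate user / assistant turns (hence the
+# odd-length assert).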
+
+
+Item = Dict[str, float]
+
+
+def read_tensorboard_file(fpath: str) -> Dict[str, List[Item]]:
+ if not os.path.isfile(fpath):
+ raise FileNotFoundError(f'fpath: {fpath}')
+ ea = EventAccumulator(fpath)
+ ea.Reload()
+ res = {}
+ tags = ea.Tags()['scalars']
+ for tag in tags:
+ values = ea.Scalars(tag)
+ r = []
+ for v in values:
+ r.append({'step': v.step, 'value': v.value})
+ res[tag] = r
+ return res
+
+
+def tensorboard_smoothing(values: List[float],
+ smooth: float = 0.9) -> List[float]:
+ norm_factor = 1
+ x = 0
+ res = []
+ for i in range(len(values)):
+ x = x * smooth + values[i] # Exponential decay
+ res.append(x / norm_factor)
+ #
+ norm_factor *= smooth
+ norm_factor += 1
+ return res
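+
+# Worked example (smooth=0.5, values=[1.0, 2.0]):
+#   i=0: x = 1.0,                 norm = 1.0 -> 1.0
+#   i=1: x = 1.0*0.5 + 2.0 = 2.5, norm = 1.5 -> 2.5 / 1.5 ≈ 1.667
+# i.e. an exponential moving average with bias correction, similar to the
+# smoothing slider in TensorBoard.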
+
+
+def plot_image(data: Dict[str, List[Item]], key_name: str,
+ smooth: float) -> Figure:
+ _data = data[key_name]
+ steps = [d['step'] for d in _data]
+ values = [d['value'] for d in _data]
+ fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(8, 5), dpi=100)
+ ax.set_title(key_name)
+ if smooth != 0:
+ ax.plot(steps, values, color=COLOR)
+ values_s = tensorboard_smoothing(values, smooth)
+ ax.plot(steps, values_s, color=COLOR_S)
+ else:
+ ax.plot(steps, values, color=COLOR_S)
+ return fig
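+
+
+# Usage sketch (hypothetical tag name and event-file path):
+#   data = read_tensorboard_file('runs/baichuan/.../events.out.tfevents.xxx')
+#   fig = plot_image(data, 'loss', smooth=0.9)
+#   fig.savefig('loss.png')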
diff --git a/examples/pytorch/llm_agent/baichuan_infer.ipynb b/examples/pytorch/llm_agent/baichuan_infer.ipynb
new file mode 100644
index 00000000..7ef29951
--- /dev/null
+++ b/examples/pytorch/llm_agent/baichuan_infer.ipynb
@@ -0,0 +1,482 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Baichuan 推理"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 配置实验环境"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2023-07-02 22:28:00,199] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 22:28:00,675 - modelscope - INFO - PyTorch version 2.0.1 Found.\n",
+ "2023-07-02 22:28:00,676 - modelscope - INFO - Loading ast index from /home/hackathon/.cache/modelscope/ast_indexer\n",
+ "2023-07-02 22:28:00,700 - modelscope - INFO - Loading done! Current index file version is 1.6.2, with md5 ddf811ee982377c1357284a2bfda3dec and a total number of 861 components indexed\n",
+ "2023-07-02 22:28:01,367 - modelscope - INFO - [0, 1]\n",
+ "2023-07-02 22:28:01,512 - modelscope - INFO - Using device: cuda:0,1\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda', index=0)"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from _common import *\n",
+ "from transformers import TextStreamer\n",
+ "device_ids = [0, 1]\n",
+ "select_device(device_ids)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 导入Model, Tokenizer\n",
+ "Note: 你需要设置CKPT_FPATH的内容, 指向`.bin`文件, 或`.pth`文件"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 22:28:03,375 - modelscope - INFO - Model revision not specified, use default: master in development mode\n",
+ "2023-07-02 22:28:03,375 - modelscope - INFO - Development mode use revision: master\n",
+ "2023-07-02 22:28:03,695 - modelscope - INFO - model_config: BaiChuanConfig {\n",
+ " \"architectures\": [\n",
+ " \"BaiChuanForCausalLM\"\n",
+ " ],\n",
+ " \"auto_map\": {\n",
+ " \"AutoConfig\": \"configuration_baichuan.BaiChuanConfig\",\n",
+ " \"AutoModelForCausalLM\": \"modeling_baichuan.BaiChuanForCausalLM\"\n",
+ " },\n",
+ " \"bos_token_id\": 1,\n",
+ " \"eos_token_id\": 2,\n",
+ " \"hidden_act\": \"silu\",\n",
+ " \"hidden_size\": 4096,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 11008,\n",
+ " \"max_position_embeddings\": 4096,\n",
+ " \"model_type\": \"baichuan\",\n",
+ " \"num_attention_heads\": 32,\n",
+ " \"num_hidden_layers\": 32,\n",
+ " \"pad_token_id\": 0,\n",
+ " \"rms_norm_eps\": 1e-06,\n",
+ " \"tie_word_embeddings\": false,\n",
+ " \"torch_dtype\": \"float16\",\n",
+ " \"transformers_version\": \"4.30.2\",\n",
+ " \"use_cache\": true,\n",
+ " \"vocab_size\": 64000\n",
+ "}\n",
+ "\n",
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "BaiChuanForCausalLM(\n",
+ " (model): Model(\n",
+ " (embed_tokens): Embedding(64000, 4096, padding_idx=0)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x DecoderLayer(\n",
+ " (self_attn): Attention(\n",
+ " (W_pack): Linear(in_features=4096, out_features=12288, bias=False)\n",
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (rotary_emb): RotaryEmbedding()\n",
+ " )\n",
+ " (mlp): MLP(\n",
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (act_fn): SiLUActivation()\n",
+ " )\n",
+ " (input_layernorm): RMSNorm()\n",
+ " (post_attention_layernorm): RMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): RMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=64000, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "CKPT_FAPTH = '/home/hackathon/my_git/agent/runs/baichuan/v10-20230702-172449/output_best/pytorch_model.bin'\n",
+ "LORA_TARGET_MODULES = ['W_pack']\n",
+ "\n",
+ "model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
+ "model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
+ "model.bfloat16() # Consistent with training"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 导入Lora"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 22:28:14,108 - modelscope - INFO - lora_config: LoRAConfig(rank=8, replace_modules=['W_pack'], lora_alpha=32, lora_dropout=0, merge_weights=True, use_merged_linear=False, enable_lora=None, fan_in_fan_out=False, bias='none', only_lora_trainable=True, pretrained_weights='/home/hackathon/my_git/agent/runs/baichuan/v10-20230702-172449/output_best/pytorch_model.bin')\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "BaiChuanForCausalLM(\n",
+ " (model): Model(\n",
+ " (embed_tokens): Embedding(64000, 4096, padding_idx=0)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x DecoderLayer(\n",
+ " (self_attn): Attention(\n",
+ " (W_pack): Linear(in_features=4096, out_features=12288, bias=False)\n",
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (rotary_emb): RotaryEmbedding()\n",
+ " )\n",
+ " (mlp): MLP(\n",
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (act_fn): SiLUActivation()\n",
+ " )\n",
+ " (input_layernorm): RMSNorm()\n",
+ " (post_attention_layernorm): RMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): RMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=64000, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "LORA_RANK = 8\n",
+ "LORA_ALPHA = 32\n",
+ "LORA_DROPOUT_P = 0 # Arbitrary value\n",
+ "lora_config = LoRAConfig(\n",
+ " replace_modules=LORA_TARGET_MODULES,\n",
+ " rank=LORA_RANK,\n",
+ " lora_alpha=LORA_ALPHA,\n",
+ " lora_dropout=LORA_DROPOUT_P,\n",
+ " pretrained_weights=CKPT_FAPTH)\n",
+ "logger.info(f'lora_config: {lora_config}')\n",
+ "Swift.prepare_model(model, lora_config)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 导入Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 22:28:28,832 - modelscope - INFO - No subset_name specified, defaulting to the default\n",
+ "2023-07-02 22:28:29,317 - modelscope - WARNING - Reusing dataset ms_hackathon_23_agent_train_dev (/home/hackathon/.cache/modelscope/hub/datasets/modelscope/ms_hackathon_23_agent_train_dev/master/data_files)\n",
+ "2023-07-02 22:28:29,318 - modelscope - INFO - Generating dataset ms_hackathon_23_agent_train_dev (/home/hackathon/.cache/modelscope/hub/datasets/modelscope/ms_hackathon_23_agent_train_dev/master/data_files)\n",
+ "2023-07-02 22:28:29,318 - modelscope - INFO - Reusing cached meta-data file: /home/hackathon/.cache/modelscope/hub/datasets/modelscope/ms_hackathon_23_agent_train_dev/master/data_files/941b733ec0354c2172a3386d8788bb37\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "682dc9eedfce4092a25fcadc977c794a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading data files: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8e53d79d8e4845618231f3afb5bc096f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Extracting data files: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 285/285 [00:00<00:00, 1566679.74it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "test_dataset = make_dataset('validation', lambda system, user, assistant:\n",
+ " {'system': system, 'user': user, 'assistant': assistant})"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 推理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TEST] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。1. {\"plugin_name\": \"modelscope_speech-generation\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_speech-generation\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"url\": \"http://90.49.118.175:2603/\", \"paths\": [{\"name\": \"modelscope_speech-generation\", \"model_id\": \"/damo/speech_sambert-hifigan_tts_zh-cn_16k\", \"method\": \"post\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"parameters\": [{\"name\": \"text\", \"description\": \"要转成语音的文本\", \"required\": \"True\"}, {\"name\": \"gender\", \"description\": \"用户身份\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "2. {\"plugin_name\": \"modelscope_speech-generation\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_speech-generation\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"url\": \"http://132.94.116.115:5983/\", \"paths\": [{\"name\": \"modelscope_speech-generation\", \"model_id\": \"/damo/speech_sambert-hifigan_tts_zh-cn_16k\", \"method\": \"post\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"parameters\": [{\"name\": \"text\", \"description\": \"要转成语音的文本\", \"required\": \"True\"}, {\"name\": \"gender\", \"description\": \"用户身份\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "3. {\"plugin_name\": \"modelscope_speech-generation\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_speech-generation\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"url\": \"http://94.43.176.75:1062/\", \"paths\": [{\"name\": \"modelscope_speech-generation\", \"model_id\": \"/damo/speech_sambert-hifigan_tts_zh-cn_16k\", \"method\": \"post\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"parameters\": [{\"name\": \"text\", \"description\": \"要转成语音的文本\", \"required\": \"True\"}, {\"name\": \"gender\", \"description\": \"用户身份\", \"required\": \"True\"}]}]}} \n",
+ "\n",
+ "### 用户\n",
+ "生成一首诗歌,主题为“秋天的美景”,读出来这段话 \n",
+ "\n",
+ "### 助手\n",
+ "秋天,是一个美丽的季节,是一个收获的季节,是一个充满诗意的季节。秋天的天空,湛蓝湛蓝的,像一块蓝宝石;秋天的田野,金黄色的稻谷,像一片金色的海洋;秋天的果园,硕果累累,像一幅美丽的画卷。秋天的山林,层林尽染,像一幅色彩斑斓的油画;秋天的河流,清澈见底,像一条银色的丝带。秋天的天空,湛蓝湛蓝的,像一块蓝宝石;秋天的田野,金黄色的稻谷,像一片金色的海洋;秋天的果园,硕果累累,像一幅美丽的画卷。秋天的山林,层林尽染,像一幅色彩斑斓的油画;秋天的河流,清澈见底,像一条银色的丝带。\n",
+ "\n",
+ "[LABELS]秋树红叶舞飘零,\n",
+ "山间小溪水潺潺。\n",
+ "微风拂面感清凉,\n",
+ "散步赏景心旷神怡。\n",
+ "<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_speech-generation\", \"url\": \"http://90.49.118.175:2603/damo/speech_sambert-hifigan_tts_zh-cn_16k\", \"parameters\": {\"text\": \"秋树红叶舞飘零,\n",
+ "山间小溪水潺潺。\n",
+ "微风拂面感清凉,\n",
+ "散步赏景心旷神怡。\", \"gender\": \"woman\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"result\": \"\"}\n",
+ "```<|endofexec|>\n",
+ "\n",
+ "-----------------------------------------------------------------------------------\n",
+ "[TEST] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。1. {\"plugin_name\": \"modelscope_text-address\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-address\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"url\": \"http://159.1.4.174:3210/\", \"paths\": [{\"name\": \"modelscope_text-address\", \"model_id\": \"/damo/mgeo_geographic_elements_tagging_chinese_base\", \"method\": \"post\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的地址信息\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "2. {\"plugin_name\": \"modelscope_text-address\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-address\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"url\": \"http://172.163.158.154:5325/\", \"paths\": [{\"name\": \"modelscope_text-address\", \"model_id\": \"/damo/mgeo_geographic_elements_tagging_chinese_base\", \"method\": \"post\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的地址信息\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "3. {\"plugin_name\": \"modelscope_text-address\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-address\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"url\": \"http://133.94.12.37:3160/\", \"paths\": [{\"name\": \"modelscope_text-address\", \"model_id\": \"/damo/mgeo_geographic_elements_tagging_chinese_base\", \"method\": \"post\", \"description\": \"针对中文的地址信息,识别出里面的元素,包括省、市、区、镇、社区、道路、路号、POI、楼栋号、户室号等\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的地址信息\", \"required\": \"True\"}]}]}} \n",
+ "\n",
+ "### 用户\n",
+ "现在我给你另一条地址,请识别出里面的元素。输入地址:广东省深圳市南山区科技园北区 \n",
+ "\n",
+ "### 助手\n",
+ "<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-address\", \"url\": \"http://133.94.12.37:3160/damo/mgeo_geographic_elements_tagging_chinese_base\", \"parameters\": {\"text\": \"广东省深圳市南山区科技园北区\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"prov\": \"广东省\", \"city\": \"深圳市\", \"district\": \"南山区\", \"community\": \"科技园北区\"}\n",
+ "```<|endofexec|>\n",
+ "地址识别json表示:{\"prov\": \"广东省\", \"city\": \"深圳市\", \"district\": \"南山区\", \"community\": \"科技园北区\"}。我使用的模型是ModelScope的'damo/mgeo_geographic_elements_tagging_chinese_base'模型。这是基于达摩院联合高德发布的多任务多模态地址预训练底座MGeo模型微调得到的。\n",
+ "\n",
+ "[LABELS]<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-address\", \"url\": \"http://159.1.4.174:3210/damo/mgeo_geographic_elements_tagging_chinese_base\", \"parameters\": {\"text\": \"广东省深圳市南山区科技园北区\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"prov\": \"广东省\", \"city\": \"深圳市\", \"district\": \"南山区\", \"town\": \"\", \"community\": \"科技园北区\", \"poi\": \"\"}\n",
+ "```<|endofexec|>\n",
+ "地址识别json表示:{\"prov\": \"广东省\", \"city\": \"深圳市\", \"district\": \"南山区\", \"town\": \"\", \"community\": \"科技园北区\", \"poi\": \"\"}。我使用的模型是ModelScope的'damo/mgeo_geographic_elements_tagging_chinese_base'模型。这是基于达摩院联合高德发布的多任务多模态地址预训练底座MGeo模型微调得到的。\n",
+ "-----------------------------------------------------------------------------------\n",
+ "[TEST] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。目前支持的插件信息如下,请自行判断是否需要调用插件来解决当前用户问题。若需要调用插件,则需要将插件调用请求按照json格式给出,必须包含api_name、url、parameters字段,并在其前后使用<|startofthink|>和<|endofthink|>作为标志。然后你需要根据插件API调用结果生成合理的答复;若无需调用插件,则直接给出对应回复即可:\n",
+ "\n",
+ "1. {\"name\": \"modelscope_text-translation-zh2en\", \"description\": \"将输入的中文文本翻译成英文\", \"url\": \"http://api-inference.modelscope.cn/api-inference/v1/models\", \"paths\": [{\"name\": \"modelscope_text-translation-zh2en\", \"model_id\": \"/damo/nlp_csanmt_translation_zh2en\", \"method\": \"post\", \"description\": \"将输入的中文文本翻译成英文\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的中文文本\", \"required\": \"True\"}]}]}\n",
+ "\n",
+ "2. {\"name\": \"modelscope_speech-generation\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"url\": \"http://api-inference.modelscope.cn/api-inference/v1/models\", \"paths\": [{\"name\": \"modelscope_speech-generation\", \"model_id\": \"/damo/speech_sambert-hifigan_tts_zh-cn_16k\", \"method\": \"post\", \"description\": \"针对回复的内容,用语音表示,同时可以选择是男声或者女声\", \"parameters\": [{\"name\": \"text\", \"description\": \"要转成语音的文本\", \"required\": \"True\"}, {\"name\": \"gender\", \"description\": \"用户身份\", \"required\": \"True\"}]}]}\n",
+ "\n",
+ "3. {\"name\": \"modelscope_image-generation\", \"description\": \"针对文本输入,生成对应的图片\", \"url\": \"http://api-inference.modelscope.cn/api-inference/v1/models\", \"paths\": [{\"name\": \"modelscope_image-generation\", \"model_id\": \"/damo/image_generation\", \"method\": \"post\", \"description\": \"针对文本输入,生成对应的图片\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本信息\", \"required\": \"True\"}]}]} \n",
+ "\n",
+ "### 用户\n",
+ "歌手:古巨基\n",
+ "歌曲名:爱情马戏班\n",
+ "经典歌词:情是何等诡秘能令人使出看家把戏;恋爱就像走纲线般惊险;为你献技 像马戏班\n",
+ "请结合以上信息,编写一个智能音响的播放导语,需要有文采,字数30字以内,凸显一下即将播放该歌曲 \n",
+ "\n",
+ "### 助手\n",
+ "古巨基的《爱情马戏班》,是一首经典的情歌,歌词中充满了对爱情的向往和对爱情的渴望,让人不禁沉醉其中。这首歌的旋律优美动听,歌词朗朗上口,让人听了就忍不住跟着哼唱。\n",
+ "\n",
+ "[LABELS]亲爱的主人,今天我为您带来的是古巨基的经典之作——《爱情马戏班》。这首歌曲描绘了情与爱的神秘和惊险,让人们为之倾倒。让我们一起享受这场爱情的马戏表演吧!\n",
+ "-----------------------------------------------------------------------------------\n",
+ "[TEST] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。1. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://114.42.178.183:8005/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "2. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://93.82.87.89:6631/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "3. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://4.105.93.165:8143/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}} \n",
+ "\n",
+ "### 用户\n",
+ "按照给定的schema抽取出下面文本对应的信息\n",
+ "schema:{\"动物\": null, \"食物\": null, \"颜色\": null}\n",
+ "这只棕色的狗狗很喜欢吃狗粮。 \n",
+ "\n",
+ "### 助手\n",
+ "<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-ie\", \"url\": \"http://4.105.93.165:8143/damo/nlp_structbert_siamese-uie_chinese-base\", \"parameters\": {\"text\": \"这只棕色的狗狗很喜欢吃狗粮。\", \"schema\": \"{\\\"动物\\\": null, \\\"食物\\\": null, \\\"颜色\\\": null}\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"动物\": [\"棕色的狗狗\"], \"食物\": [\"狗粮\"], \"颜色\": [\"棕色\"]}\n",
+ "```<|endofexec|>\n",
+ "信息抽取结果:{\"动物\": [\"棕色的狗狗\"], \"食物\": [\"狗粮\"], \"颜色\": [\"棕色\"]}。我使用的模型是ModelScope的'damo/nlp_structbert_siamese-uie_chinese-base'模型。这是一个基于StructBERT预训练模型微调训练的通用信息抽取模型。\n",
+ "\n",
+ "[LABELS]<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-ie\", \"url\": \"http://114.42.178.183:8005/damo/nlp_structbert_siamese-uie_chinese-base\", \"parameters\": {\"text\": \"这只棕色的狗狗很喜欢吃狗粮。\", \"schema\": \"{\\\"动物\\\": null, \\\"食物\\\": null, \\\"颜色\\\": null}\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"动物\": [\"狗狗\"], \"食物\": [\"狗粮\"], \"颜色\": [\"棕色\"]}\n",
+ "```<|endofexec|>\n",
+ "信息抽取结果:{\"动物\": [\"狗狗\"], \"食物\": [\"狗粮\"], \"颜色\": [\"棕色\"]}。我使用的模型是ModelScope的'damo/nlp_structbert_siamese-uie_chinese-base'模型。这是一个基于StructBERT预训练模型微调训练的通用信息抽取模型。\n",
+ "-----------------------------------------------------------------------------------\n",
+ "[TEST] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。1. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://28.179.171.5:6428/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "2. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://100.111.18.38:6408/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "3. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://144.67.18.142:6381/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}} \n",
+ "\n",
+ "### 用户\n",
+ "按照给定的schema抽取出下面文本对应的信息\n",
+ "schema:{\"人物\": null, \"地理位置\": null, \"组织机构\": null}\n",
+ "谷歌公司是一家全球知名的科技公司,总部位于美国的加利福尼亚州山景市。 \n",
+ "\n",
+ "### 助手\n",
+ "<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-ie\", \"url\": \"http://144.67.18.142:6381/damo/nlp_structbert_siamese-uie_chinese-base\", \"parameters\": {\"text\": \"谷歌公司是一家全球知名的科技公司,总部位于美国的加利福尼亚州山景市。\", \"schema\": \"{\\\"人物\\\": null, \\\"地理位置\\\": null, \\\"组织机构\\\": null}\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"人物\": [\"谷歌公司\"], \"地理位置\": [\"美国\"], \"组织机构\": [\"科技公司\"]}\n",
+ "```<|endofexec|>\n",
+ "信息抽取结果:{\"人物\": [\"谷歌公司\"], \"地理位置\": [\"美国\"], \"组织机构\": [\"科技公司\"]}。我使用的模型是ModelScope的'damo/nlp_structbert_siamese-uie_chinese-base'模型。这是一个基于StructBERT预训练模型微调训练的通用信息抽取模型。\n",
+ "\n",
+ "[LABELS]<|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-ie\", \"url\": \"http://100.111.18.38:6408/damo/nlp_structbert_siamese-uie_chinese-base\", \"parameters\": {\"text\": \"谷歌公司是一家全球知名的科技公司,总部位于美国的加利福尼亚州山景市。\", \"schema\": \"{\\\"人物\\\": null, \\\"地理位置\\\": null, \\\"组织机构\\\": null}\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"人物\": [], \"地理位置\": [\"美国\", \"加利福尼亚州山景市\"], \"组织机构\": [\"谷歌公司\"]}\n",
+ "```<|endofexec|>\n",
+ "信息抽取结果:{\"人物\": [], \"地理位置\": [\"美国\", \"加利福尼亚州山景市\"], \"组织机构\": [\"谷歌公司\"]}。我使用的模型是ModelScope的'damo/nlp_structbert_siamese-uie_chinese-base'模型。这是一个基于StructBERT预训练模型微调训练的通用信息抽取模型。\n",
+ "-----------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
+ "for d in test_dataset[:5]:\n",
+ " system = d['system']\n",
+ " user = d['user']\n",
+ " assistant = d['assistant']\n",
+ " input_ids = tokenize_function(system, user, None, tokenizer)['input_ids']\n",
+ " print(f'[TEST]{tokenizer.decode(input_ids)}', end='')\n",
+ " input_ids = torch.tensor(input_ids)[None].cuda()\n",
+ " attention_mask = torch.ones_like(input_ids)\n",
+ " generate_ids = model.generate(input_ids=input_ids, max_new_tokens=512,\n",
+ " attention_mask=attention_mask,\n",
+ " streamer=streamer, pad_token_id=tokenizer.eos_token_id, \n",
+ " temperature=0.7, top_k=50, top_p=0.7, do_sample=True)\n",
+ " print()\n",
+ " print(f'[LABELS]{assistant}')\n",
+ " print('-----------------------------------------------------------------------------------')\n",
+ " # input('next[ENTER]')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/pytorch/llm_agent/baichuan_sft.ipynb b/examples/pytorch/llm_agent/baichuan_sft.ipynb
new file mode 100644
index 00000000..6c41ff25
--- /dev/null
+++ b/examples/pytorch/llm_agent/baichuan_sft.ipynb
@@ -0,0 +1,1814 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Baichuan + Lora + Agent\n",
+ "baichuan-7B是由百川智能开发的一个开源的大规模预训练模型。基于Transformer结构,在大约1.2万亿tokens上训练的70亿参数模型,支持中英双语,上下文窗口长度为4096。在标准的中文和英文权威benchmark(C-EVAL/MMLU)上均取得同尺寸最好的效果。"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "1. Ref: https://modelscope.cn/models/baichuan-inc/baichuan-7B/summary\n",
+ "2. 以下脚本可以在2*A10环境下正常运行, 大概占用40G显存\n",
+ "3. python>=3.8"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 配置实验环境"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install modelscope\n",
+ "# !pip install numpy pandas matplotlib scikit-learn\n",
+ "# !pip install transformers datasets\n",
+ "# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n",
+ "# !pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate\n",
+ "\n",
+ "# !pip install numpy -U # Resolve torchmetrics dependencies and update numpy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2023-07-02 17:24:09,391] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/hackathon/miniconda3/envs/hackathon/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "2023-07-02 17:24:09,870 - modelscope - INFO - PyTorch version 2.0.1 Found.\n",
+ "2023-07-02 17:24:09,871 - modelscope - INFO - Loading ast index from /home/hackathon/.cache/modelscope/ast_indexer\n",
+ "2023-07-02 17:24:09,895 - modelscope - INFO - Loading done! Current index file version is 1.6.2, with md5 ddf811ee982377c1357284a2bfda3dec and a total number of 861 components indexed\n",
+ "2023-07-02 17:24:10,570 - modelscope - INFO - [0, 1]\n",
+ "2023-07-02 17:24:10,719 - modelscope - INFO - Using device: cuda:0,1\n",
+ "2023-07-02 17:24:10,720 - modelscope - INFO - Global seed set to 42\n"
+ ]
+ }
+ ],
+ "source": [
+ "from _common import *\n",
+ "device_ids = [0, 1]\n",
+ "select_device(device_ids)\n",
+ "_ = seed_everything(42)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 导入Model, Tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 17:24:11,036 - modelscope - INFO - Model revision not specified, use default: master in development mode\n",
+ "2023-07-02 17:24:11,037 - modelscope - INFO - Development mode use revision: master\n",
+ "2023-07-02 17:24:11,364 - modelscope - INFO - model_config: BaiChuanConfig {\n",
+ " \"architectures\": [\n",
+ " \"BaiChuanForCausalLM\"\n",
+ " ],\n",
+ " \"auto_map\": {\n",
+ " \"AutoConfig\": \"configuration_baichuan.BaiChuanConfig\",\n",
+ " \"AutoModelForCausalLM\": \"modeling_baichuan.BaiChuanForCausalLM\"\n",
+ " },\n",
+ " \"bos_token_id\": 1,\n",
+ " \"eos_token_id\": 2,\n",
+ " \"hidden_act\": \"silu\",\n",
+ " \"hidden_size\": 4096,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 11008,\n",
+ " \"max_position_embeddings\": 4096,\n",
+ " \"model_type\": \"baichuan\",\n",
+ " \"num_attention_heads\": 32,\n",
+ " \"num_hidden_layers\": 32,\n",
+ " \"pad_token_id\": 0,\n",
+ " \"rms_norm_eps\": 1e-06,\n",
+ " \"tie_word_embeddings\": false,\n",
+ " \"torch_dtype\": \"float16\",\n",
+ " \"transformers_version\": \"4.30.2\",\n",
+ " \"use_cache\": true,\n",
+ " \"vocab_size\": 64000\n",
+ "}\n",
+ "\n",
+ "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+ ]
+ }
+ ],
+ "source": [
+ "WORK_DIR = 'runs/baichuan'\n",
+ "LORA_TARGET_MODULES = ['W_pack']\n",
+ "#\n",
+ "model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
+ "model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
+ "#\n",
+ "GRADIENT_CHECKPOINTING = True\n",
+ "if GRADIENT_CHECKPOINTING:\n",
+ " model.gradient_checkpointing_enable()\n",
+ " model.enable_input_require_grads()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 准备Lora"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-07-02 17:24:21,741 - modelscope - INFO - lora_config: LoRAConfig(rank=8, replace_modules=['W_pack'], lora_alpha=32, lora_dropout=0.1, merge_weights=True, use_merged_linear=False, enable_lora=None, fan_in_fan_out=False, bias='none', only_lora_trainable=True, pretrained_weights=None)\n",
+ "2023-07-02 17:24:36,360 - modelscope - INFO - model.embed_tokens.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,360 - modelscope - INFO - model.layers.0.self_attn.W_pack.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,361 - modelscope - INFO - model.layers.0.self_attn.W_pack.lora_A: requires_grad=True\n",
+ "2023-07-02 17:24:36,361 - modelscope - INFO - model.layers.0.self_attn.W_pack.lora_B: requires_grad=True\n",
+ "2023-07-02 17:24:36,361 - modelscope - INFO - model.layers.0.self_attn.o_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,362 - modelscope - INFO - model.layers.0.mlp.gate_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,362 - modelscope - INFO - model.layers.0.mlp.down_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,363 - modelscope - INFO - model.layers.0.mlp.up_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,363 - modelscope - INFO - model.layers.0.input_layernorm.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,363 - modelscope - INFO - model.layers.0.post_attention_layernorm.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,363 - modelscope - INFO - model.layers.1.self_attn.W_pack.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,364 - modelscope - INFO - model.layers.1.self_attn.W_pack.lora_A: requires_grad=True\n",
+ "2023-07-02 17:24:36,364 - modelscope - INFO - model.layers.1.self_attn.W_pack.lora_B: requires_grad=True\n",
+ "2023-07-02 17:24:36,364 - modelscope - INFO - model.layers.1.self_attn.o_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,364 - modelscope - INFO - model.layers.1.mlp.gate_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,365 - modelscope - INFO - model.layers.1.mlp.down_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,365 - modelscope - INFO - model.layers.1.mlp.up_proj.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,365 - modelscope - INFO - model.layers.1.input_layernorm.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,365 - modelscope - INFO - model.layers.1.post_attention_layernorm.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,365 - modelscope - INFO - model.layers.2.self_attn.W_pack.weight: requires_grad=False\n",
+ "2023-07-02 17:24:36,366 - modelscope - INFO - ...\n",
+ "2023-07-02 17:24:36,368 - modelscope - INFO - BaiChuanForCausalLM: 7004.7539M Params (4.1943M Trainable), 33.5565M Buffers.\n",
+ "2023-07-02 17:24:36,370 - modelscope - INFO - device: cuda:0, dtype: torch.float16\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "BaiChuanForCausalLM(\n",
+ " (model): Model(\n",
+ " (embed_tokens): Embedding(64000, 4096, padding_idx=0)\n",
+ " (layers): ModuleList(\n",
+ " (0-31): 32 x DecoderLayer(\n",
+ " (self_attn): Attention(\n",
+ " (W_pack): Linear(\n",
+ " in_features=4096, out_features=12288, bias=False\n",
+ " (lora_dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
+ " (rotary_emb): RotaryEmbedding()\n",
+ " )\n",
+ " (mlp): MLP(\n",
+ " (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
+ " (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
+ " (act_fn): SiLUActivation()\n",
+ " )\n",
+ " (input_layernorm): RMSNorm()\n",
+ " (post_attention_layernorm): RMSNorm()\n",
+ " )\n",
+ " )\n",
+ " (norm): RMSNorm()\n",
+ " )\n",
+ " (lm_head): Linear(in_features=4096, out_features=64000, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "LORA_RANK = 8\n",
+ "LORA_ALPHA = 32\n",
+ "LORA_DROPOUT_P = 0.1\n",
+ "lora_config = LoRAConfig(\n",
+ " replace_modules=LORA_TARGET_MODULES,\n",
+ " rank=LORA_RANK,\n",
+ " lora_alpha=LORA_ALPHA,\n",
+ " lora_dropout=LORA_DROPOUT_P)\n",
+ "logger.info(f'lora_config: {lora_config}')\n",
+ "Swift.prepare_model(model, lora_config)\n",
+ "#\n",
+ "show_freeze_layers(model)\n",
+ "print_model_info(model)\n",
+ "_p = list(model.parameters())[100]\n",
+ "logger.info(f'device: {_p.device}, dtype: {_p.dtype}')\n",
+ "model.bfloat16()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 导入Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 5036/5036 [00:12<00:00, 398.82it/s]\n",
+ "100%|██████████| 285/285 [00:00<00:00, 383.15it/s]\n",
+ "2023-07-02 17:24:49,863 - modelscope - INFO - Dataset Token Length: 958.649707±371.357483, min=44.000000, max=2045.000000, size=4953\n",
+ "2023-07-02 17:24:49,864 - modelscope - INFO - Dataset Token Length: 993.447653±337.821458, min=75.000000, max=1946.000000, size=277\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[INPUT_IDS] 你是达摩院的ModelScopeGPT(魔搭助手),你是个大语言模型, 是2023年达摩院的工程师训练得到的。你有多种能力,可以通过插件集成魔搭社区的模型api来回复用户的问题,还能解答用户使用模型遇到的问题和模型知识相关问答。1. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://109.199.101.10:1485/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "2. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://9.32.64.200:5873/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}}\n",
+ "\n",
+ "3. {\"plugin_name\": \"modelscope_text-ie\", \"plugin_owner\": \"ModelScopeGPT\", \"plugin_type\": \"default\", \"plugin_schema_for_model\": {\"name\": \"modelscope_text-ie\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"url\": \"http://54.149.78.185:3979/\", \"paths\": [{\"name\": \"modelscope_text-ie\", \"model_id\": \"/damo/nlp_structbert_siamese-uie_chinese-base\", \"method\": \"post\", \"description\": \"针对中文的文本,根据schema要抽取的内容,找出其中对应信息,并用json格式展示\", \"parameters\": [{\"name\": \"text\", \"description\": \"用户输入的文本\", \"required\": \"True\"}, {\"name\": \"schema\", \"description\": \"要抽取信息的json表示\", \"required\": \"True\"}]}]}} \n",
+ "\n",
+ "### 用户\n",
+ "按照给定的schema抽取出下面文本对应的信息\n",
+ "schema:{\"人物\": null, \"地理位置\": null, \"组织机构\": null}\n",
+ "近日,美国政府宣布将对中国1000多种商品加征关税,并威胁进一步加征关税。 \n",
+ "\n",
+ "### 助手\n",
+ " <|startofthink|>```JSON\n",
+ "{\"api_name\": \"modelscope_text-ie\", \"url\": \"http://9.32.64.200:5873/damo/nlp_structbert_siamese-uie_chinese-base\", \"parameters\": {\"text\": \"近日,美国政府宣布将对中国1000多种商品加征关税,并威胁进一步加征关税。\", \"schema\": \"{\\\"人物\\\": null, \\\"地理位置\\\": null, \\\"组织机构\\\": null}\"}}\n",
+ "```<|endofthink|>\n",
+ "\n",
+ "<|startofexec|>```JSON\n",
+ "{\"人物\": [], \"地理位置\": [\"中国\", \"美国\"], \"组织机构\": []}\n",
+ "```<|endofexec|>\n",
+ "信息抽取结果:{\"人物\": [], \"地理位置\": [\"中国\", \"美国\"], \"组织机构\": []}。我使用的模型是ModelScope的'damo/nlp_structbert_siamese-uie_chinese-base'模型。这是一个基于StructBERT预训练模型微调训练的通用信息抽取模型。\n",
+ "\n",
+ "[LABLES]