From 4ca937d2bad2899003d7b6deb3f94f7e8212dee1 Mon Sep 17 00:00:00 2001 From: Jintao Date: Wed, 26 Jul 2023 18:12:55 +0800 Subject: [PATCH] support openbuddy-llama2-13b (#416) --- examples/pytorch/llm/_parser.py | 69 ++++++ examples/pytorch/llm/llm_infer.py | 77 ++++--- examples/pytorch/llm/llm_sft.py | 94 ++++---- examples/pytorch/llm/run_infer.sh | 8 +- examples/pytorch/llm/run_sft.sh | 7 +- examples/pytorch/llm/utils/__init__.py | 5 + examples/pytorch/llm/utils/dataset.py | 72 ++++++ examples/pytorch/llm/utils/models.py | 133 +++++++++++ .../llm/{_common.py => utils/utils.py} | 216 +----------------- 9 files changed, 385 insertions(+), 296 deletions(-) create mode 100644 examples/pytorch/llm/_parser.py create mode 100644 examples/pytorch/llm/utils/__init__.py create mode 100644 examples/pytorch/llm/utils/dataset.py create mode 100644 examples/pytorch/llm/utils/models.py rename examples/pytorch/llm/{_common.py => utils/utils.py} (54%) diff --git a/examples/pytorch/llm/_parser.py b/examples/pytorch/llm/_parser.py new file mode 100644 index 00000000..480cfdce --- /dev/null +++ b/examples/pytorch/llm/_parser.py @@ -0,0 +1,69 @@ +import os +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, Type, TypeVar, Union + +import torch +from torch import device as Device +from transformers import HfArgumentParser + +from modelscope import get_logger + +logger = get_logger() + + +def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]: + if isinstance(device, list): + device_ids = device + device_str = ','.join([str(d) for d in device]) + else: + device_ids = [int(d) for d in device.split(',') if d != '-1'] + device_str = device + device_str = device_str.replace(' ', '') + return device_ids, device_str + + +def select_device(device: Union[List[int], str]) -> Device: + """Call this function before cuda is initialized. + device: e.g. []: 'cpu', [0], [0, 1, 2] + e.g. '-1': 'cpu', '0', '0,1,2' + """ + if torch.cuda.is_initialized(): + logger.warning('CUDA has been initialized! Device selection fails!') + return torch.device('cuda:0') + + device_ids, device_str = _format_device(device) + os.environ['CUDA_VISIBLE_DEVICES'] = device_str + log_s = 'Using device: ' + if len(device_ids) == 0: + master_device: str = 'cpu' + log_s += 'cpu' + else: + assert torch.cuda.is_available( + ) and torch.cuda.device_count() >= len(device_ids) + master_device = 'cuda:0' + log_s += f'cuda:{device_str}' + logger.info(log_s) + return torch.device(master_device) + + +_T = TypeVar('_T') + + +def parse_args(class_type: Type[_T], + argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]: + parser = HfArgumentParser([class_type]) + args, remaining_args = parser.parse_args_into_dataclasses( + argv, return_remaining_strings=True) + logger.info(f'args: {args}') + return args, remaining_args + + +@dataclass +class DeviceArguments: + device: str = '0' # e.g. '-1'; '0'; '0,1' + + +def parse_device(argv: Optional[List[str]] = None) -> List[str]: + args, remaining_args = parse_args(DeviceArguments, argv) + select_device(args.device) + return remaining_args diff --git a/examples/pytorch/llm/llm_infer.py b/examples/pytorch/llm/llm_infer.py index 8b9c1bb1..614e3d36 100644 --- a/examples/pytorch/llm/llm_infer.py +++ b/examples/pytorch/llm/llm_infer.py @@ -1,21 +1,32 @@ # ### Setting up experimental environment. -from _common import * + +if __name__ == '__main__': + # Avoid cuda initialization caused by library import (e.g. peft, accelerate) + from _parser import * + # argv = parse_device(['--device', '1']) + argv = parse_device() + +from utils import * @dataclass class InferArguments: - device: str = '0' # e.g. '-1'; '0'; '0,1' model_type: str = field( - default='baichuan-7b', - metadata={ - 'choices': - ['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b'] - }) + default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())}) sft_type: str = field( default='lora', metadata={'choices': ['lora', 'full']}) ckpt_path: str = '/path/to/your/iter_xxx.pth' eval_human: bool = False # False: eval test_dataset - data_sample: Optional[int] = None + ignore_args_error: bool = True # False: notebook compatibility + + dataset: str = field( + default='alpaca-en,alpaca-zh', + metadata={'help': f'dataset choices: {list(DATASET_MAPPER.keys())}'}) + dataset_seed: int = 42 + dataset_sample: Optional[int] = None + dataset_test_size: float = 0.01 + prompt: str = DEFAULT_PROMPT + max_length: Optional[int] = 2048 lora_target_modules: Optional[List[str]] = None lora_rank: int = 8 @@ -29,33 +40,14 @@ class InferArguments: def __post_init__(self): if self.lora_target_modules is None: - if self.model_type in {'baichuan-7b', 'baichuan-13b'}: - self.lora_target_modules = ['W_pack'] - elif self.model_type == 'chatglm2': - self.lora_target_modules = ['query_key_value'] - elif self.model_type == 'llama2-7b': - self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj'] - else: - raise ValueError(f'model_type: {self.model_type}') + self.lora_target_modules = MODEL_MAPPER[self.model_type]['lora_TM'] if not os.path.isfile(self.ckpt_path): raise ValueError( f'Please enter a valid ckpt_path: {self.ckpt_path}') -def parse_args() -> InferArguments: - # return_remaining_strings=True for notebook compatibility - args, remaining_args = HfArgumentParser([ - InferArguments - ]).parse_args_into_dataclasses(return_remaining_strings=True) - logger.info(f'args: {args}') - if len(remaining_args) > 0: - logger.warning(f'remaining_args: {remaining_args}') - return args - - def llm_infer(args: InferArguments) -> None: - select_device(args.device) # ### Loading Model and Tokenizer support_bf16 = torch.cuda.is_bf16_supported() if not support_bf16: @@ -72,7 +64,7 @@ def llm_infer(args: InferArguments) -> None: lora_dropout=args.lora_dropout_p, pretrained_weights=args.ckpt_path) logger.info(f'lora_config: {lora_config}') - Swift.prepare_model(model, lora_config) + model = Swift.prepare_model(model, lora_config) elif args.sft_type == 'full': state_dict = torch.load(args.ckpt_path, map_location='cpu') model.load_state_dict(state_dict) @@ -80,6 +72,11 @@ def llm_infer(args: InferArguments) -> None: raise ValueError(f'args.sft_type: {args.sft_type}') # ### Inference + tokenize_func = partial( + tokenize_function, + tokenizer=tokenizer, + prompt=args.prompt, + max_length=args.max_length) streamer = TextStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True) generation_config = GenerationConfig( @@ -94,17 +91,22 @@ def llm_infer(args: InferArguments) -> None: if args.eval_human: while True: instruction = input('<<< ') - data = {'instruction': instruction, 'input': None, 'output': None} - inference(data, model, tokenizer, streamer, generation_config) + data = {'instruction': instruction} + input_ids = tokenize_func(data)['input_ids'] + inference(input_ids, model, tokenizer, streamer, generation_config) print('-' * 80) else: - _, test_dataset = get_alpaca_en_zh_dataset( - None, True, split_seed=42, data_sample=args.data_sample) + dataset = get_dataset(args.dataset) + _, test_dataset = process_dataset(dataset, args.dataset_test_size, + args.dataset_sample, + args.dataset_seed) mini_test_dataset = test_dataset.select(range(10)) + del dataset for data in mini_test_dataset: output = data['output'] data['output'] = None - inference(data, model, tokenizer, streamer, generation_config) + input_ids = tokenize_func(data)['input_ids'] + inference(input_ids, model, tokenizer, streamer, generation_config) print() print(f'[LABELS]{output}') print('-' * 80) @@ -112,5 +114,10 @@ def llm_infer(args: InferArguments) -> None: if __name__ == '__main__': - args = parse_args() + args, remaining_argv = parse_args(InferArguments, argv) + if len(remaining_argv) > 0: + if args.ignore_args_error: + logger.warning(f'remaining_argv: {remaining_argv}') + else: + raise ValueError(f'remaining_argv: {remaining_argv}') llm_infer(args) diff --git a/examples/pytorch/llm/llm_sft.py b/examples/pytorch/llm/llm_sft.py index 07f1fd5e..a7dabf77 100644 --- a/examples/pytorch/llm/llm_sft.py +++ b/examples/pytorch/llm/llm_sft.py @@ -1,37 +1,45 @@ # ### Setting up experimental environment. """ -pip install numpy pandas matplotlib scikit-learn -pip install transformers datasets -conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer -pip install accelerate transformers_stream_generator - # Install the latest version of modelscope from source git clone https://github.com/modelscope/modelscope.git cd modelscope pip install . -# Resolve torchmetrics dependencies and update numpy -pip install numpy -U +conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia +pip install numpy pandas -U # Resolve torchmetrics dependencies and update numpy +pip install matplotlib scikit-learn -U +pip install transformers datasets -U +pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer -U +pip install accelerate transformers_stream_generator -U """ -from _common import * +if __name__ == '__main__': + # Avoid cuda initialization caused by library import (e.g. peft, accelerate) + from _parser import * + # argv = parse_device(['--device', '1']) + argv = parse_device() + +from utils import * @dataclass class SftArguments: - device: str = '0,1' # e.g. '-1'; '0'; '0,1' seed: int = 42 model_type: str = field( - default='baichuan-7b', - metadata={ - 'choices': - ['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b'] - }) + default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())}) # baichuan-7b: 'lora': 16G; 'full': 80G sft_type: str = field( default='lora', metadata={'choices': ['lora', 'full']}) - data_sample: Optional[int] = None + ignore_args_error: bool = True # False: notebook compatibility + + dataset: str = field( + default='alpaca-en,alpaca-zh', + metadata={'help': f'dataset choices: {list(DATASET_MAPPER.keys())}'}) + dataset_seed: int = 42 + dataset_sample: Optional[int] = None + dataset_test_size: float = 0.01 + prompt: str = DEFAULT_PROMPT + max_length: Optional[int] = 2048 lora_target_modules: Optional[List[str]] = None lora_rank: int = 8 @@ -75,29 +83,10 @@ class SftArguments: raise ValueError(f'sft_type: {self.sft_type}') if self.lora_target_modules is None: - if self.model_type in {'baichuan-7b', 'baichuan-13b'}: - self.lora_target_modules = ['W_pack'] - elif self.model_type == 'chatglm2': - self.lora_target_modules = ['query_key_value'] - elif self.model_type == 'llama2-7b': - self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj'] - else: - raise ValueError(f'model_type: {self.model_type}') - - -def parse_args() -> SftArguments: - # return_remaining_strings=True for notebook compatibility - args, remaining_args = HfArgumentParser([ - SftArguments - ]).parse_args_into_dataclasses(return_remaining_strings=True) - logger.info(f'args: {args}') - if len(remaining_args) > 0: - logger.warning(f'remaining_args: {remaining_args}') - return args + self.lora_target_modules = MODEL_MAPPER[self.model_type]['lora_TM'] def llm_sft(args: SftArguments) -> None: - select_device(args.device) seed_everything(args.seed) # ### Loading Model and Tokenizer @@ -123,18 +112,28 @@ def llm_sft(args: SftArguments) -> None: lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout_p) logger.info(f'lora_config: {lora_config}') - Swift.prepare_model(model, lora_config) + model = Swift.prepare_model(model, lora_config) show_freeze_layers(model) print_model_info(model) # check the device and dtype of the model - _p: Parameter = list(model.parameters())[-1] + _p: Tensor = list(model.parameters())[-1] logger.info(f'device: {_p.device}, dtype: {_p.dtype}') # ### Loading Dataset - tokenize_func = partial(tokenize_function, tokenizer=tokenizer) - train_dataset, val_dataset = get_alpaca_en_zh_dataset( - tokenize_func, split_seed=42, data_sample=args.data_sample) + dataset = get_dataset(args.dataset) + train_dataset, val_dataset = process_dataset(dataset, + args.dataset_test_size, + args.dataset_sample, + args.dataset_seed) + tokenize_func = partial( + tokenize_function, + tokenizer=tokenizer, + prompt=args.prompt, + max_length=args.max_length) + train_dataset = train_dataset.map(tokenize_func) + val_dataset = val_dataset.map(tokenize_func) + del dataset # Data analysis stat_dataset(train_dataset) stat_dataset(val_dataset) @@ -239,11 +238,6 @@ def llm_sft(args: SftArguments) -> None: cfg.update(config) return cfg - device_kwargs = {} - if torch.cuda.device_count() > 1: - # No placement for model, leave the model to `device_map` - device_kwargs['device'] = 'cpu' - trainer = EpochBasedTrainer( model=model, cfg_file=cfg_file, @@ -253,7 +247,6 @@ def llm_sft(args: SftArguments) -> None: remove_unused_data=True, seed=42, cfg_modify_fn=cfg_modify_fn, - **device_kwargs, ) trainer.train() @@ -264,5 +257,10 @@ def llm_sft(args: SftArguments) -> None: if __name__ == '__main__': - args = parse_args() + args, remaining_argv = parse_args(SftArguments, argv) + if len(remaining_argv) > 0: + if args.ignore_args_error: + logger.warning(f'remaining_argv: {remaining_argv}') + else: + raise ValueError(f'remaining_argv: {remaining_argv}') llm_sft(args) diff --git a/examples/pytorch/llm/run_infer.sh b/examples/pytorch/llm/run_infer.sh index efe48958..aa1a1a04 100644 --- a/examples/pytorch/llm/run_infer.sh +++ b/examples/pytorch/llm/run_infer.sh @@ -1,5 +1,7 @@ +#!/bin/bash + python llm_infer.py \ - --device 0 \ - --model_type llama2-7b \ - --ckpt_path "runs/llama2-7b/vx_xxx/output_best/pytorch_model.bin" \ + --device 0,1 \ + --model_type openbuddy-llama2-13b \ + --ckpt_path "runs/openbuddy-llama2-13b/vx_xxx/output_best/pytorch_model.bin" \ --eval_human true diff --git a/examples/pytorch/llm/run_sft.sh b/examples/pytorch/llm/run_sft.sh index 98ae2460..3a6d9ff4 100644 --- a/examples/pytorch/llm/run_sft.sh +++ b/examples/pytorch/llm/run_sft.sh @@ -2,7 +2,8 @@ DATE=$(date +"%Y%m%d-%H%M%S") nohup python llm_sft.py \ - --device 0 \ - --model_type llama2-7b \ - --data_sample 25000 \ + --device 0,1 \ + --model_type openbuddy-llama2-13b \ + --dataset alpaca-en,alpaca-zh \ + --dataset_sample 20000 \ &> train_$DATE.out & diff --git a/examples/pytorch/llm/utils/__init__.py b/examples/pytorch/llm/utils/__init__.py new file mode 100644 index 00000000..e4772c03 --- /dev/null +++ b/examples/pytorch/llm/utils/__init__.py @@ -0,0 +1,5 @@ +from _parser import * + +from .dataset import * +from .models import * +from .utils import * diff --git a/examples/pytorch/llm/utils/dataset.py b/examples/pytorch/llm/utils/dataset.py new file mode 100644 index 00000000..3035ba78 --- /dev/null +++ b/examples/pytorch/llm/utils/dataset.py @@ -0,0 +1,72 @@ +from typing import Optional, Tuple + +import numpy as np +from datasets import Dataset as HfDataset +from datasets import concatenate_datasets +from numpy.random import RandomState + +from modelscope import MsDataset + + +def _processing_alpaca(dataset: HfDataset) -> HfDataset: + instruction = dataset['instruction'] + input_ = dataset['input'] + res = [] + for inst, inp in zip(instruction, input_): + if inp is not None and inp != '': + if inp.startswith('输入:'): + inp = inp[3:] + inst = f'{inst}\n{inp}' + res.append(inst) + dataset = HfDataset.from_dict({ + 'instruction': res, + 'output': dataset['output'] + }) + return dataset + + +def get_alpaca_en_dataset() -> HfDataset: + dataset_en: HfDataset = MsDataset.load( + 'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset() + dataset_en = dataset_en.remove_columns(['text']) + return _processing_alpaca(dataset_en) + + +def get_alpaca_zh_dataset() -> HfDataset: + dataset_zh: HfDataset = MsDataset.load( + 'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset() + return _processing_alpaca(dataset_zh) + + +def get_seed(random_state: RandomState) -> int: + seed_max = np.iinfo(np.int32).max + seed = random_state.randint(0, seed_max) + return seed + + +def process_dataset(dataset: HfDataset, dataset_test_size: float, + dataset_sample: Optional[int], + dataset_seed: int) -> Tuple[HfDataset, HfDataset]: + random_state = np.random.RandomState(dataset_seed) + if dataset_sample is not None: + index = random_state.permutation(len(dataset))[:dataset_sample] + dataset = dataset.select(index) + dataset = dataset.train_test_split( + dataset_test_size, seed=get_seed(random_state)) + return dataset['train'], dataset['test'] + + +DATASET_MAPPER = { + 'alpaca-en': get_alpaca_en_dataset, + 'alpaca-zh': get_alpaca_zh_dataset, +} + + +def get_dataset(dataset_names: str) -> HfDataset: + dataset_name_list = dataset_names.split(',') + dataset_list = [] + for dataset_name in dataset_name_list: + get_function = DATASET_MAPPER[dataset_name] + dataset_list.append(get_function()) + dataset = concatenate_datasets(dataset_list) + return dataset diff --git a/examples/pytorch/llm/utils/models.py b/examples/pytorch/llm/utils/models.py new file mode 100644 index 00000000..c95df561 --- /dev/null +++ b/examples/pytorch/llm/utils/models.py @@ -0,0 +1,133 @@ +from typing import NamedTuple + +import torch +from torch import dtype as Dtype + +from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model, + get_logger, read_config, snapshot_download) +from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer + +logger = get_logger() + + +def _add_special_token(tokenizer): + if tokenizer.eos_token_id is None: + tokenizer.eos_token_id = 2 + if tokenizer.bos_token_id is None: + tokenizer.bos_token_id = 1 + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = 0 + logger.info(f'bos_token_id: {tokenizer.bos_token_id}, ' + f'eos_token_id: {tokenizer.eos_token_id}, ' + f'pad_token_id: {tokenizer.pad_token_id}') + + +def get_model_tokenizer_default(model_dir: str, + load_model: bool = True, + add_special_token: bool = True, + torch_dtype: Dtype = torch.float16): + """load from an independent repository""" + model_config = AutoConfig.from_pretrained( + model_dir, trust_remote_code=True) + model_config.torch_dtype = torch_dtype + logger.info(f'model_config: {model_config}') + tokenizer = AutoTokenizer.from_pretrained( + model_dir, trust_remote_code=True) + model = None + if load_model: + model = AutoModelForCausalLM.from_pretrained( + model_dir, + config=model_config, + device_map='auto', + torch_dtype=torch_dtype, + trust_remote_code=True) + + if add_special_token: + _add_special_token(tokenizer) + return model, tokenizer + + +def get_model_tokenizer_chatglm2(model_dir: str, + load_model: bool = True, + add_special_token: bool = True, + torch_dtype: Dtype = torch.float16): + """load from ms library""" + config = read_config(model_dir) + logger.info(config) + model_config = ChatGLM2Config.from_pretrained(model_dir) + model_config.torch_dtype = torch_dtype + logger.info(model_config) + tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir) + model = None + if load_model: + model = Model.from_pretrained( + model_dir, + cfg_dict=config, + config=model_config, + device_map='auto', + torch_dtype=torch_dtype) + if add_special_token: + _add_special_token(tokenizer) + return model, tokenizer + + +class LoRATM(NamedTuple): + # default lora target modules + baichuan = ['W_pack'] + chatglm2 = ['query_key_value'] + llama2 = ['q_proj', 'k_proj', 'v_proj'] + + +# Reference: 'https://modelscope.cn/models/{model_id}/summary' +MODEL_MAPPER = { + 'baichuan-7b': { + 'model_id': 'baichuan-inc/baichuan-7B', + 'revision': 'v1.0.7', + 'lora_TM': LoRATM.baichuan + }, + 'baichuan-13b': { + 'model_id': 'baichuan-inc/Baichuan-13B-Base', + 'revision': 'v1.0.3', + 'lora_TM': LoRATM.baichuan + }, + 'chatglm2': { + 'model_id': 'ZhipuAI/chatglm2-6b', + 'revision': 'v1.0.6', + 'get_function': get_model_tokenizer_chatglm2, + 'lora_TM': LoRATM.chatglm2 + }, + 'llama2-7b': { + 'model_id': 'modelscope/Llama-2-7b-ms', + 'revision': 'v1.0.2', + 'ignore_file_pattern': [r'.+\.bin$'], # use safetensors + 'lora_TM': LoRATM.llama2 + }, + 'llama2-13b': { + 'model_id': 'modelscope/Llama-2-13b-ms', + 'revision': 'v1.0.2', + 'ignore_file_pattern': [r'.+\.bin$'], + 'lora_TM': LoRATM.llama2 + }, + 'openbuddy-llama2-13b': { + 'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', + 'lora_TM': LoRATM.llama2 + } +} + + +def get_model_tokenizer(model_type: str, + load_model: bool = True, + add_special_token: bool = True, + torch_dtype: Dtype = torch.float16): + data = MODEL_MAPPER.get(model_type) + if data is None: + raise ValueError(f'model_type: {model_type}') + model_id = data['model_id'] + revision = data.get('revision', 'master') + get_function = data.get('get_function', get_model_tokenizer_default) + ignore_file_pattern = data.get('ignore_file_pattern', []) + model_dir = snapshot_download( + model_id, revision, ignore_file_pattern=ignore_file_pattern) + model, tokenizer = get_function(model_dir, load_model, add_special_token, + torch_dtype) + return model, tokenizer, model_dir diff --git a/examples/pytorch/llm/_common.py b/examples/pytorch/llm/utils/utils.py similarity index 54% rename from examples/pytorch/llm/_common.py rename to examples/pytorch/llm/utils/utils.py index b8921581..5b8ee163 100644 --- a/examples/pytorch/llm/_common.py +++ b/examples/pytorch/llm/utils/utils.py @@ -9,16 +9,10 @@ from functools import partial from types import MethodType from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import json import matplotlib.pyplot as plt import numpy as np import torch -import torch.nn as nn -import torch.optim as optim from datasets import Dataset as HfDataset -from datasets import concatenate_datasets -from matplotlib.axes import Axes -from matplotlib.figure import Figure from numpy import ndarray from tensorboard.backend.event_processing.event_accumulator import \ EventAccumulator @@ -26,23 +20,14 @@ from torch import Tensor from torch import device as Device from torch import dtype as Dtype from torch.nn import Module -from torch.nn.parameter import Parameter from torch.nn.utils.rnn import pad_sequence -from torch.optim import Optimizer -from torch.optim import lr_scheduler as lrs -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import Dataset from torchmetrics import Accuracy, MeanMetric from tqdm import tqdm -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - GenerationConfig, HfArgumentParser, TextStreamer) +from transformers import GenerationConfig, TextStreamer -from modelscope import (Model, MsDataset, get_logger, read_config, - snapshot_download) +from modelscope import get_logger from modelscope.metrics.base import Metric from modelscope.metrics.builder import METRICS -from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer -from modelscope.models.nlp.llama2 import Llama2Config, Llama2Tokenizer from modelscope.swift import LoRAConfig, Swift from modelscope.trainers import EpochBasedTrainer from modelscope.utils.config import Config, ConfigDict @@ -50,7 +35,7 @@ from modelscope.utils.registry import default_group COLOR, COLOR_S = '#FFE2D9', '#FF7043' -PROMPT = """Here's a conversation between a human and an AI assistant. \ +DEFAULT_PROMPT = """Here's a conversation between a human and an AI assistant. \ The AI assistant provides detailed, friendly answers for the human. ### Human: @@ -89,41 +74,6 @@ def get_work_dir(work_dir: str) -> str: return work_dir -def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]: - if isinstance(device, list): - device_ids = device - device_str = ','.join([str(d) for d in device]) - else: - device_ids = [int(d) for d in device.split(',') if d != '-1'] - device_str = device - device_str = device_str.replace(' ', '') - return device_ids, device_str - - -def select_device(device: Union[List[int], str]) -> Device: - """Call this function before cuda is initialized. - device: e.g. []: 'cpu', [0], [0, 1, 2] - e.g. '-1': 'cpu', '0', '0,1,2' - """ - if torch.cuda.is_initialized(): - logger.warning('CUDA has been initialized! Device selection fails!') - return torch.device('cuda:0') - - device_ids, device_str = _format_device(device) - os.environ['CUDA_VISIBLE_DEVICES'] = device_str - log_s = 'Using device: ' - if len(device_ids) == 0: - master_device: str = 'cpu' - log_s += 'cpu' - else: - assert torch.cuda.is_available( - ) and torch.cuda.device_count() >= len(device_ids) - master_device = 'cuda:0' - log_s += f'cuda:{device_str}' - logger.info(log_s) - return torch.device(master_device) - - def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int: if seed is None: seed_max = np.iinfo(np.int32).max @@ -154,16 +104,11 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int, def tokenize_function(example: Dict[str, Optional[str]], tokenizer, + prompt: str = DEFAULT_PROMPT, max_length: Optional[int] = 2048) -> Dict[str, Any]: instruction: str = example['instruction'] - input_ = example['input'] - if input_ is not None and input_ != '': - if input_.startswith('输入:'): - instruction = instruction + input_[3:] - else: - instruction = instruction + input_ - output = example['output'] - src_text = PROMPT.format(instruction=instruction) + output = example.get('output') + src_text = prompt.format(instruction=instruction) src_input_ids: List[int] = tokenizer( src_text, return_attention_mask=False, add_special_tokens=True)['input_ids'] @@ -271,7 +216,7 @@ class MyMetric(Metric): def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]) -> None: loss: Tensor = outputs.loss - self.loss.update(loss) + self.loss.update(loss.cpu()) labels: Tensor = inputs['labels'] labels = labels[:, 1:] @@ -280,7 +225,7 @@ class MyMetric(Metric): logits = logits[labels_mask].contiguous().view(-1, logits.shape[-1]) pred = logits.argmax(dim=-1) labels = labels[labels_mask].to(logits.device) - self.acc.update(pred, labels) + self.acc.update(pred.cpu(), labels.cpu()) def evaluate(self): return { @@ -293,148 +238,6 @@ class MyMetric(Metric): raise NotImplementedError -def _add_special_token(tokenizer): - if tokenizer.eos_token_id is None: - tokenizer.eos_token_id = 2 - if tokenizer.bos_token_id is None: - tokenizer.bos_token_id = 1 - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = 0 - logger.info(f'bos_token_id: {tokenizer.bos_token_id}, ' - f'eos_token_id: {tokenizer.eos_token_id}, ' - f'pad_token_id: {tokenizer.pad_token_id}') - - -def get_baichuan_model_tokenizer(model_dir: str, - load_model: bool = True, - add_special_token: bool = True, - torch_dtype: Dtype = torch.float16): - model_config = AutoConfig.from_pretrained( - model_dir, trust_remote_code=True) - model_config.torch_dtype = torch_dtype - logger.info(f'model_config: {model_config}') - tokenizer = AutoTokenizer.from_pretrained( - model_dir, trust_remote_code=True) - model = None - if load_model: - model = AutoModelForCausalLM.from_pretrained( - model_dir, - config=model_config, - device_map='auto', - torch_dtype=torch_dtype, - trust_remote_code=True) - - if add_special_token: - _add_special_token(tokenizer) - return model, tokenizer - - -def get_chatglm2_model_tokenizer(model_dir: str, - load_model: bool = True, - add_special_token: bool = True, - torch_dtype: Dtype = torch.float16): - config = read_config(model_dir) - logger.info(config) - model_config = ChatGLM2Config.from_pretrained(model_dir) - model_config.torch_dtype = torch_dtype - logger.info(model_config) - tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir) - model = None - if load_model: - model = Model.from_pretrained( - model_dir, - cfg_dict=config, - config=model_config, - device_map='auto', - torch_dtype=torch_dtype) - if add_special_token: - _add_special_token(tokenizer) - return model, tokenizer - - -def get_llama2_model_tokenizer(model_dir: str, - load_model: bool = True, - add_special_token: bool = True, - torch_dtype: Dtype = torch.float16): - config = read_config(model_dir) - logger.info(config) - model_config = Llama2Config.from_pretrained(model_dir) - model_config.torch_dtype = torch_dtype - logger.info(model_config) - tokenizer = Llama2Tokenizer.from_pretrained(model_dir) - model = None - if load_model: - model = Model.from_pretrained( - model_dir, - cfg_dict=config, - config=model_config, - device_map='auto', - torch_dtype=torch_dtype) - if add_special_token: - _add_special_token(tokenizer) - return model, tokenizer - - -def get_model_tokenizer(model_type: str, - load_model: bool = True, - add_special_token: bool = True, - torch_dtype: Dtype = torch.float16): - # ### Loading Model and Tokenizer - if model_type == 'baichuan-7b': - model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.7') - model, tokenizer = get_baichuan_model_tokenizer( - model_dir, load_model, add_special_token, torch_dtype) - elif model_type == 'baichuan-13b': - model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', - 'v1.0.3') - model, tokenizer = get_baichuan_model_tokenizer( - model_dir, load_model, add_special_token, torch_dtype) - elif model_type == 'chatglm2': - model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6') - model, tokenizer = get_chatglm2_model_tokenizer( - model_dir, load_model, add_special_token, torch_dtype) - elif model_type == 'llama2-7b': - # use `.safetensors` - model_dir = snapshot_download( - 'modelscope/Llama-2-7b-ms', - 'v1.0.2', - ignore_file_pattern=[r'.+\.bin$']) - model, tokenizer = get_llama2_model_tokenizer(model_dir, load_model, - add_special_token, - torch_dtype) - else: - raise ValueError(f'model_type: {model_type}') - return model, tokenizer, model_dir - - -def get_alpaca_en_zh_dataset( - tokenize_function, - only_val: bool = False, - test_split_p: float = 0.01, - split_seed: int = 42, - data_sample: Optional[int] = None) -> Tuple[HfDataset, HfDataset]: - dataset_en: HfDataset = MsDataset.load( - 'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset() - dataset_zh: HfDataset = MsDataset.load( - 'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset() - dataset_en = dataset_en.remove_columns(['text']) - dataset: HfDataset = concatenate_datasets([dataset_zh, dataset_en]) - - if data_sample is not None: - dataset = dataset.select(range(data_sample)) - dataset = dataset.train_test_split(test_split_p, seed=split_seed) - if only_val: - dataset = dataset['test'] - if tokenize_function is not None: - dataset = dataset.map(tokenize_function) - dataset = dataset.remove_columns(['instruction', 'input', 'output']) - - if only_val: - return None, dataset - else: - return dataset['train'], dataset['test'] - - Item = Dict[str, float] @@ -500,13 +303,12 @@ def plot_images(tb_dir: str, plt.savefig(fpath, dpi=dpi, bbox_inches='tight') -def inference(data: Dict[str, Optional[str]], +def inference(input_ids: List[int], model, tokenizer, streamer: Optional[TextStreamer] = None, generation_config: Optional[GenerationConfig] = None, tag: str = '[INFERENCE]') -> str: - input_ids = tokenize_function(data, tokenizer)['input_ids'] print(f'{tag}{tokenizer.decode(input_ids)}', end='') input_ids = torch.tensor(input_ids)[None].cuda() attention_mask = torch.ones_like(input_ids)