# ### Setting up experimental environment.
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import torch
from swift import LoRAConfig, Swift
from transformers import GenerationConfig, TextStreamer
from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING, get_dataset,
                   get_model_tokenizer, inference, parse_args, process_dataset,
                   tokenize_function)

from modelscope import get_logger

warnings.warn(
    'This directory has been migrated to '
    'https://github.com/modelscope/swift/tree/main/examples/pytorch/llm, '
    'and the files in this directory are no longer maintained.',
    DeprecationWarning)

logger = get_logger()


@dataclass
class InferArguments:
    model_type: str = field(
        default='qwen-7b', metadata={'choices': list(MODEL_MAPPING.keys())})
    sft_type: str = field(
        default='lora', metadata={'choices': ['lora', 'full']})
    ckpt_path: str = '/path/to/your/iter_xxx.pth'
    eval_human: bool = False  # False: evaluate on the test_dataset instead
    ignore_args_error: bool = False  # True: notebook compatibility

    dataset: str = field(
        default='alpaca-en,alpaca-zh',
        metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
    dataset_seed: int = 42
    dataset_sample: int = 20000  # -1: use the whole dataset
    dataset_test_size: float = 0.01
    prompt: str = DEFAULT_PROMPT
    max_length: Optional[int] = 2048

    lora_target_modules: Optional[List[str]] = None
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout_p: float = 0.1

    max_new_tokens: int = 512
    temperature: float = 0.9
    top_k: int = 50
    top_p: float = 0.9

    def __post_init__(self):
        if self.lora_target_modules is None:
            self.lora_target_modules = MODEL_MAPPING[
                self.model_type]['lora_TM']

        if not os.path.isfile(self.ckpt_path):
            raise ValueError(
                f'Please enter a valid ckpt_path: {self.ckpt_path}')
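
# Hedged usage sketch (the script name below is hypothetical): assuming
# `parse_args` from utils maps each InferArguments field to a --flag of the
# same name, an invocation might look like:
#   python llm_infer.py --model_type qwen-7b --sft_type lora \
#       --ckpt_path /path/to/your/iter_xxx.pth --eval_human true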


def llm_infer(args: InferArguments) -> None:
    # ### Loading Model and Tokenizer
    support_bf16 = torch.cuda.is_bf16_supported()
    if not support_bf16:
        logger.warning(f'support_bf16: {support_bf16}')
    # Fall back to float16 on GPUs without bfloat16 support, rather than
    # loading the weights in an unsupported dtype.
    torch_dtype = torch.bfloat16 if support_bf16 else torch.float16

    kwargs = {'low_cpu_mem_usage': True, 'device_map': 'auto'}
    model, tokenizer, _ = get_model_tokenizer(
        args.model_type, torch_dtype=torch_dtype, **kwargs)
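    # Note: device_map='auto' lets accelerate place the weights across all
    # visible GPUs automatically, and low_cpu_mem_usage reduces peak host
    # memory while the checkpoint is being loaded.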

    # ### Preparing LoRA
    if args.sft_type == 'lora':
        lora_config = LoRAConfig(
            target_modules=args.lora_target_modules,
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout_p,
            pretrained_weights=args.ckpt_path)
        logger.info(f'lora_config: {lora_config}')
        model = Swift.prepare_model(model, lora_config)
        state_dict = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)
    elif args.sft_type == 'full':
        state_dict = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f'Unsupported sft_type: {args.sft_type}')
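    # Both branches restore fine-tuned weights from args.ckpt_path; in the
    # 'lora' branch the adapter modules are injected first by
    # Swift.prepare_model, so the checkpoint's state_dict (which presumably
    # holds the parameters saved during fine-tuning) maps onto the wrapped
    # model.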

    # ### Inference
    tokenize_func = partial(
        tokenize_function,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_length=args.max_length)
    streamer = TextStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')
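
    # `inference` (from utils) presumably wraps model.generate with the
    # streamer and generation_config above, printing tokens as they are
    # produced.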
    if args.eval_human:
        while True:
            instruction = input('<<< ')
            data = {'instruction': instruction}
            input_ids = tokenize_func(data)['input_ids']
            inference(input_ids, model, tokenizer, streamer, generation_config)
            print('-' * 80)
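            # There is no built-in exit condition; interrupt with Ctrl-C
            # (KeyboardInterrupt) to leave the loop.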
    else:
        dataset = get_dataset(args.dataset.split(','))
        _, test_dataset = process_dataset(dataset, args.dataset_test_size,
                                          args.dataset_sample,
                                          args.dataset_seed)
        mini_test_dataset = test_dataset.select(range(10))
        del dataset
        for data in mini_test_dataset:
            # Hide the reference answer from the prompt, generate, then print
            # it as [LABELS] for side-by-side comparison.
            output = data['output']
            data['output'] = None
            input_ids = tokenize_func(data)['input_ids']
            inference(input_ids, model, tokenizer, streamer, generation_config)
            print()
            print(f'[LABELS]{output}')
            print('-' * 80)
            # input('next[ENTER]')


if __name__ == '__main__':
    args, remaining_argv = parse_args(InferArguments)
    if len(remaining_argv) > 0:
        # Notebook kernels inject extra argv entries; ignore_args_error=True
        # downgrades the failure to a warning in that setting.
        if args.ignore_args_error:
            logger.warning(f'remaining_argv: {remaining_argv}')
        else:
            raise ValueError(f'remaining_argv: {remaining_argv}')
    llm_infer(args)