diff --git a/data/test b/data/test index acc59489..c117008c 160000 --- a/data/test +++ b/data/test @@ -1 +1 @@ -Subproject commit acc59489d3954fc09cd99d8d4aed818a8d39b283 +Subproject commit c117008caa9dc447c208e9ed6bc11310512d4a3a diff --git a/modelscope/tuners/__init__.py b/examples/pytorch/chatglm6b/__init__.py similarity index 100% rename from modelscope/tuners/__init__.py rename to examples/pytorch/chatglm6b/__init__.py diff --git a/examples/pytorch/chatglm6b/chatglm_trainer.py b/examples/pytorch/chatglm6b/chatglm_trainer.py new file mode 100644 index 00000000..b34563bd --- /dev/null +++ b/examples/pytorch/chatglm6b/chatglm_trainer.py @@ -0,0 +1,118 @@ +from typing import Any, Dict, Union + +import numpy as np +import torch +from transformers.deepspeed import is_deepspeed_zero3_enabled + +from modelscope import EpochBasedTrainer, get_logger + +logger = get_logger(__name__) + + +class Seq2SeqTrainer(EpochBasedTrainer): + + def _decode(self, tokens, ignore_pad_token_for_loss=False): + tokens = tokens.cpu().numpy() + if ignore_pad_token_for_loss: + tokens = np.where(tokens != -100, tokens, + self.tokenizer.pad_token_id) + return [ + t for t in self.tokenizer.batch_decode( + tokens, skip_special_tokens=True) if t != '' + ] + + def evaluation_step( + self, + inputs: Dict[str, Union[torch.Tensor, Any]], + ): + has_labels = 'labels' in inputs + # XXX: adapt synced_gpus for fairscale as well + gen_kwargs = self.cfg['gen_kwargs'] + if gen_kwargs.get('max_length') is None and gen_kwargs.get( + 'max_new_tokens') is None: + gen_kwargs['max_length'] = self.model.config.max_length + gen_kwargs['num_beams'] = ( + gen_kwargs['num_beams'] if gen_kwargs.get('num_beams') is not None + else self.model.config.num_beams) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs['synced_gpus'] = ( + gen_kwargs['synced_gpus'] if gen_kwargs.get('synced_gpus') + is not None else default_synced_gpus) + + if 'attention_mask' in inputs: + gen_kwargs['attention_mask'] = inputs.get('attention_mask', None) + if 'position_ids' in inputs: + gen_kwargs['position_ids'] = inputs.get('position_ids', None) + if 'global_attention_mask' in inputs: + gen_kwargs['global_attention_mask'] = inputs.get( + 'global_attention_mask', None) + + # prepare generation inputs + # some encoder-decoder models can have varying encoder's and thus + # varying model input names + if hasattr( + self.model, 'encoder' + ) and self.model.encoder.main_input_name != self.model.main_input_name: + generation_inputs = inputs[self.model.encoder.main_input_name] + else: + generation_inputs = inputs[self.model.main_input_name] + + gen_kwargs['input_ids'] = generation_inputs + gen_kwargs['pad_token_id'] = self.tokenizer.pad_token_id + generated_tokens = self.model.generate(**gen_kwargs) + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + + # in case the batch is shorter than max length, the output should be padded + if gen_kwargs.get('max_length') is not None and generated_tokens.shape[ + -1] < gen_kwargs['max_length']: + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_length']) + elif gen_kwargs.get('max_new_tokens' + ) is not None and generated_tokens.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_new_tokens'] + 1) + + if has_labels: + labels = inputs['labels'] + if gen_kwargs.get('max_length') is not None and labels.shape[ + -1] < gen_kwargs['max_length']: + labels = self._pad_tensors_to_max_len(labels, + gen_kwargs['max_length']) + elif gen_kwargs.get( + 'max_new_tokens') is not None and labels.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + labels = self._pad_tensors_to_max_len( + labels, (gen_kwargs['max_new_tokens'] + 1)) + else: + labels = None + + generated_tokens = [ + ''.join(self._decode(seq, False)) for seq in generated_tokens + ] + inputs['tgts'] = [''.join(self._decode(seq, True)) for seq in labels] + return { + 'preds': generated_tokens, + } + + def _pad_tensors_to_max_len(self, tensor, max_length): + if self.tokenizer is not None and hasattr(self.tokenizer, + 'pad_token_id'): + # If PAD token is not defined at least EOS token has to be defined + pad_token_id = ( + self.tokenizer.pad_token_id if self.tokenizer.pad_token_id + is not None else self.tokenizer.eos_token_id) + else: + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError( + 'Pad_token_id must be set in the configuration of the model, in order to pad tensors' + ) + + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), + dtype=tensor.dtype, + device=tensor.device) + padded_tensor[:, :tensor.shape[-1]] = tensor + return padded_tensor diff --git a/examples/pytorch/chatglm6b/finetune.py b/examples/pytorch/chatglm6b/finetune.py new file mode 100644 index 00000000..2dc85f2a --- /dev/null +++ b/examples/pytorch/chatglm6b/finetune.py @@ -0,0 +1,380 @@ +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +from chatglm_trainer import Seq2SeqTrainer +from text_generation_metric import TextGenerationMetric +from transformers import DataCollatorForSeq2Seq + +from modelscope import snapshot_download +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.trainers.training_args import TrainingArgs +from modelscope.utils.config import ConfigDict +from modelscope.utils.hub import read_config + + +@dataclass(init=False) +class Chatglm6bArguments(TrainingArgs): + ptuning_checkpoint: str = field( + default=None, + metadata={ + 'help': 'The p-tuning checkpoint previously trained.', + }) + + pre_seq_len: int = field( + default=None, metadata={ + 'help': 'The p-tuning sequence length', + }) + + prefix_projection: bool = field( + default=False, metadata={ + 'help': '', + }) + + quantization_bit: int = field( + default=None, metadata={ + 'help': 'Quantized bit', + }) + + prompt_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the full texts (for summarization).' + }, + ) + + response_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the summaries (for summarization).' + }, + ) + + history_column: str = field( + default=None, + metadata={ + 'help': + 'The name of the column in the datasets containing the history of chat.' + }, + ) + + source_prefix: str = field( + default='', + metadata={ + 'help': + 'A prefix to add before every source text (useful for T5 models).' + }) + + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + 'help': + 'Whether to ignore the tokens corresponding to padded labels in the loss computation or not.' + }, + ) + + max_source_length: int = field( + default=1024, + metadata={ + 'help': + ('The maximum total input sequence length after tokenization. Sequences longer ' + 'than this will be truncated, sequences shorter will be padded.') + }, + ) + + max_target_length: int = field( + default=128, + metadata={ + 'help': + ('The maximum total sequence length for target text after tokenization. Sequences longer ' + 'than this will be truncated, sequences shorter will be padded.') + }, + ) + + max_train_samples: int = field( + default=None, + metadata={ + 'help': + ('For debugging purposes or quicker training, truncate the number of training examples to this ' + 'value if set.') + }, + ) + + max_eval_samples: int = field( + default=None, + metadata={ + 'help': + ('For debugging purposes or quicker training, truncate the number of evaluation examples to this ' + 'value if set.') + }, + ) + + preprocessing_num_workers: int = field( + default=None, + metadata={ + 'help': 'The number of processes to use for the preprocessing.' + }, + ) + + use_lora: int = field( + default=0, + metadata={'help': 'Whether to use lora to train the model.'}, + ) + + lora_rank: int = field( + default=32, + metadata={'help': 'The lora rank'}, + ) + + lora_alpha: int = field( + default=32, + metadata={'help': 'The lora alpha'}, + ) + + lora_dropout: float = field( + default=0.05, + metadata={'help': 'The lora alpha'}, + ) + + +args = Chatglm6bArguments(eval_metrics='chatglm').parse_cli() +print(args) +config, _ = args.to_config(ignore_default_config=args.use_model_config) +config.dump('./configuration.json') + +if config['model']['type'] == 'chatglm6b': + from modelscope.models.nlp import ChatGLMTokenizer +else: + from modelscope.models.nlp import ChatGLM2Tokenizer as ChatGLMTokenizer + + +def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + if cfg.train.lr_scheduler.type == 'LinearLR': + cfg.train.lr_scheduler['total_iters'] = \ + int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs + cfg['gen_kwargs'] = { + 'do_sample': True, + 'top_p': 0.7, + 'max_length': 512, + 'temperature': 0.95 + } + return cfg + + +train_dataset = MsDataset.load( + args.train_dataset_name, + subset_name=args.train_subset_name, + split=args.train_split) +validation_dataset = MsDataset.load( + args.val_dataset_name, + subset_name=args.val_subset_name, + split=args.val_split) + +model_dir = snapshot_download(args.model) +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': config['model']['type'], +}) + +if config['model']['type'] == 'chatglm6b': + model_config['model']['pre_seq_len'] = args.pre_seq_len + model_config['model']['prefix_projection'] = args.prefix_projection + +tokenizer = ChatGLMTokenizer.from_pretrained(model_dir, trust_remote_code=True) +model = Model.from_pretrained(model_dir, cfg_dict=model_config) + +if args.ptuning_checkpoint is not None: + # Evaluation + # Loading extra state dict of prefix encoder + + prefix_state_dict = torch.load( + os.path.join(args.ptuning_checkpoint, 'pytorch_model.bin')) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith('transformer.prefix_encoder.'): + new_prefix_state_dict[k[len('transformer.prefix_encoder.'):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + +if args.quantization_bit is not None: + print(f'Quantized to {args.quantization_bit} bit') + model = model.quantize(args.quantization_bit) +if args.pre_seq_len is not None: + # P-tuning v2 + model = model.half() + model.transformer.prefix_encoder.float() +else: + # Finetune + model = model.float() + +if args.use_lora != 0: + lora_config = LoRAConfig( + replace_modules=['attention.query_key_value'], + rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout) + model = model.bfloat16() + Swift.prepare_model(model, lora_config) + +prefix = args.source_prefix if args.source_prefix is not None else '' + +# Get the column names for input/target. +prompt_column = args.prompt_column +response_column = args.response_column +history_column = args.history_column + +# Temporarily set max_target_length for training. +max_target_length = args.max_target_length + +model_parameters = filter(lambda p: p.requires_grad, model.parameters()) +trainable_params = sum([np.prod(p.size()) for p in model_parameters]) + +model_parameters = filter(lambda p: not p.requires_grad, model.parameters()) +non_trainable_params = sum([np.prod(p.size()) for p in model_parameters]) + +print('trainable_params:{} ({:.2f}%), non_trainable_params:{}'.format( + trainable_params, trainable_params / non_trainable_params * 100, + non_trainable_params)) + + +def preprocess_function_eval(examples): + inputs, targets = [], [] + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = '' + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + turn_idx, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs.append(prompt) + targets.append(examples[response_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer( + inputs, + max_length=args.max_source_length, + truncation=True, + padding=True) + labels = tokenizer( + text_target=targets, max_length=max_target_length, truncation=True) + + if args.ignore_pad_token_for_loss: + labels['input_ids'] = [[(lb if lb != tokenizer.pad_token_id else -100) + for lb in label] + for label in labels['input_ids']] + model_inputs['labels'] = labels['input_ids'] + + return model_inputs + + +def preprocess_function_train(examples): + max_seq_length = args.max_source_length + args.max_target_length + + model_inputs = { + 'input_ids': [], + 'labels': [], + } + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query, answer = examples[prompt_column][i], examples[ + response_column][i] + + if history_column is None: + prompt = query + else: + prompt = '' + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + turn_idx, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > args.max_source_length - 1: + a_ids = a_ids[:args.max_source_length - 1] + + if len(b_ids) > args.max_target_length - 2: + b_ids = b_ids[:args.max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens( + a_ids, b_ids) + + if config['model']['type'] == 'chatglm6b': + context_length = input_ids.index(tokenizer.bos_token_id) + else: + context_length = len(a_ids) + 2 + mask_position = context_length - 1 + labels = [-100] * context_length + input_ids[mask_position + 1:] + + pad_len = max_seq_length - len(input_ids) + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + if config['model']['type'] == 'chatglm6b': + labels = labels + [tokenizer.pad_token_id] * pad_len + if args.ignore_pad_token_for_loss: + labels = [(lb if lb != tokenizer.pad_token_id else -100) + for lb in labels] + else: + labels = labels + [-100] * pad_len + + model_inputs['input_ids'].append(input_ids) + model_inputs['labels'].append(labels) + + return model_inputs + + +train_dataset = train_dataset.to_hf_dataset().map( + preprocess_function_train, + batched=True, + num_proc=args.preprocessing_num_workers, + desc='Running tokenizer on train dataset', +) + +validation_dataset = validation_dataset.to_hf_dataset().map( + preprocess_function_eval, + batched=True, + num_proc=args.preprocessing_num_workers, + desc='Running tokenizer on eval dataset', +) + +# Data collator +label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id +data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + padding=False) + +model.gradient_checkpointing_enable() +if config['model']['type'] == 'chatglm6b': + model.enable_input_require_grads() + +trainer = Seq2SeqTrainer( + model=model, + cfg_file='./configuration.json', + train_dataset=train_dataset, + eval_dataset=validation_dataset, + seed=args.seed, + data_collator=data_collator, + remove_unused_data=True, + cfg_modify_fn=cfg_modify_fn) +trainer.tokenizer = tokenizer +trainer.train() diff --git a/examples/pytorch/chatglm6b/lora_inference.py b/examples/pytorch/chatglm6b/lora_inference.py new file mode 100644 index 00000000..aa86e890 --- /dev/null +++ b/examples/pytorch/chatglm6b/lora_inference.py @@ -0,0 +1,31 @@ +from modelscope import Model, pipeline, read_config +from modelscope.metainfo import Models +from modelscope.swift import Swift +from modelscope.swift.lora import LoRAConfig +from modelscope.utils.config import ConfigDict + +lora_config = LoRAConfig( + replace_modules=['attention.query_key_value'], + rank=32, + lora_alpha=32, + lora_dropout=0.05, + pretrained_weights='./lora_dureader_target/iter_600.pth') + +model_dir = 'ZhipuAI/chatglm2-6b' +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': Models.chatglm2_6b, +}) + +model = Model.from_pretrained(model_dir, cfg_dict=model_config) +model = model.bfloat16() +Swift.prepare_model(model, lora_config) + +pipe = pipeline('chat', model, pipeline_name='chatglm2_6b-text-generation') + +print( + pipe({ + 'text': + '纵使进入21世纪后,我国教育水平有了明显进步,高考的难度却依旧不容小觑,高考被中国学生和家长定义为改变命运、改写人生脑重要考试,为了这场考试,学生和家长都付出了很多。', + 'history': [] + })) diff --git a/examples/pytorch/chatglm6b/ptuning_inference.py b/examples/pytorch/chatglm6b/ptuning_inference.py new file mode 100644 index 00000000..ab32bec0 --- /dev/null +++ b/examples/pytorch/chatglm6b/ptuning_inference.py @@ -0,0 +1,34 @@ +import torch + +from modelscope import Model, pipeline, read_config +from modelscope.metainfo import Models +from modelscope.utils.config import ConfigDict + +model_dir = 'ZhipuAI/ChatGLM-6B' +model_config = read_config(model_dir) +model_config['model'] = ConfigDict({ + 'type': Models.chatglm_6b, + 'pre_seq_len': 128, + 'prefix_projection': False, +}) + +model = Model.from_pretrained(model_dir, cfg_dict=model_config) +model = model.half() +model.transformer.prefix_encoder.float() +prefix_state_dict = torch.load('./ptuning_dureader_target/iter_900.pth') +new_prefix_state_dict = {} +for k, v in prefix_state_dict.items(): + if k.startswith('transformer.prefix_encoder.'): + new_prefix_state_dict[k[len('transformer.prefix_encoder.'):]] = v +model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + +pipe = pipeline('chat', model) + +print( + pipe({ + 'text': + '维生素C也叫抗坏血酸,所以它最重要的一个作用是预防坏血病。另外,维生素C在控制感染和愈合伤口方面发挥作用,是一种强大的抗氧化剂,' + '可以中和有害的自由基。维生素C还是合成胶原蛋白的重要营养成分,胶原蛋白是结缔组织中的一种纤维蛋白,它存在于身体的各个系统中:' + '神经系统、免疫系统、骨骼系统、软骨系统、血液系统和其他系统。维生素C有助于产生作用于大脑和神经的多种激素和化学信使。', + 'history': [] + })) diff --git a/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh b/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh new file mode 100644 index 00000000..d24494cc --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_chatglm2_lora_dureader_v2.sh @@ -0,0 +1,28 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/chatglm2-6b" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 2 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm2-6b' \ + --use_lora 1 \ + --work_dir lora_dureader_target \ diff --git a/examples/pytorch/chatglm6b/run_train_lora_adv.sh b/examples/pytorch/chatglm6b/run_train_lora_adv.sh new file mode 100644 index 00000000..cb6a7856 --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_lora_adv.sh @@ -0,0 +1,24 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name AdvertiseGen/train.json \ + --val_dataset_name AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 1 \ + --save_strategy 'by_step' \ + --save_interval 1000 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 1000 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --use_lora 1 \ + --work_dir lora_adv_target \ diff --git a/examples/pytorch/chatglm6b/run_train_lora_dureader.sh b/examples/pytorch/chatglm6b/run_train_lora_dureader.sh new file mode 100644 index 00000000..26cbce15 --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_lora_dureader.sh @@ -0,0 +1,28 @@ +LR=5e-5 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 2 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --use_lora 1 \ + --work_dir lora_dureader_target \ diff --git a/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh b/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh new file mode 100644 index 00000000..667c0c96 --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_ptuning_adv.sh @@ -0,0 +1,26 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name AdvertiseGen/train.json \ + --val_dataset_name AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 1 \ + --save_strategy 'by_step' \ + --save_interval 1000 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 1000 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 \ + --work_dir ptuning_adv_target \ diff --git a/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh b/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh new file mode 100644 index 00000000..d36ad50a --- /dev/null +++ b/examples/pytorch/chatglm6b/run_train_ptuning_dureader.sh @@ -0,0 +1,30 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +PYTHONPATH=. python examples/pytorch/chatglm6b/finetune.py \ + --train_dataset_name modelscope/DuReader_robust-QG \ + --val_dataset_name modelscope/DuReader_robust-QG \ + --train_subset_name default \ + --val_subset_name default \ + --train_split train \ + --val_split validation \ + --prompt_column text1 \ + --response_column text2 \ + --model "ZhipuAI/ChatGLM-6B" \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --train.optimizer.options.cumulative_iters 1 \ + --max_epochs 3 \ + --save_strategy 'by_step' \ + --save_interval 300 \ + --lr $LR \ + --eval_strategy "by_step" \ + --eval_interval 300 \ + --lr_strategy 'by_step' \ + --task 'chat' \ + --model.type 'chatglm6b' \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 \ + --work_dir ptuning_dureader_target \ diff --git a/examples/pytorch/chatglm6b/text_generation_metric.py b/examples/pytorch/chatglm6b/text_generation_metric.py new file mode 100644 index 00000000..2083453a --- /dev/null +++ b/examples/pytorch/chatglm6b/text_generation_metric.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict, Iterable, List + +import jieba +import numpy as np +from nltk.translate.bleu_score import (SmoothingFunction, corpus_bleu, + sentence_bleu) +from rouge import Rouge + +from modelscope.metainfo import Metrics +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.utils.chinese_utils import rebuild_chinese_str +from modelscope.utils.registry import default_group + + +@METRICS.register_module(group_key=default_group, module_name='chatglm') +class TextGenerationMetric(Metric): + + def __init__(self, target_text='tgts', pred_text='preds'): + self.preds: List[str] = [] + self.tgts: List[str] = [] + self.rouge = Rouge() + self.target_text = target_text + self.pred_text = pred_text + + def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): + ground_truths = inputs[self.target_text] + eval_results = outputs[self.pred_text] + for truth in ground_truths: + self.tgts.append(truth) + for result in eval_results: + self.preds.append(result) + + def _check(self, pred: str, tgt: str) -> bool: + + def remove_useless(string: str) -> str: + return string.replace(' ', '').replace('.', '') + + return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0 + + def evaluate(self): + preds, labels = self.preds, self.tgts + if isinstance(preds, tuple): + preds = preds[0] + + score_dict = { + 'rouge-1': [], + 'rouge-2': [], + 'rouge-l': [], + 'bleu-4': [] + } + for pred, label in zip(preds, labels): + hypothesis = list(jieba.cut(pred)) + if len(hypothesis) == 0: + hypothesis = [''] + reference = list(jieba.cut(label)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis), + ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v['f'] * 100, 4)) + bleu_score = sentence_bleu( + [list(label)], + list(pred), + smoothing_function=SmoothingFunction().method3) + score_dict['bleu-4'].append(round(bleu_score * 100, 4)) + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict + + def merge(self, other: 'TextGenerationMetric'): + self.preds.extend(other.preds) + self.tgts.extend(other.tgts) + + def __getstate__(self): + return self.preds, self.tgts + + def __setstate__(self, state): + self.__init__() + self.preds, self.tgts = state diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 07d4fd3d..f82bac08 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -165,6 +165,8 @@ class Models(object): doc2bot = 'doc2bot' peer = 'peer' llama = 'llama' + chatglm_6b = 'chatglm6b' + chatglm2_6b = 'chatglm2-6b' # audio models sambert_hifigan = 'sambert-hifigan' diff --git a/modelscope/models/cv/vision_efficient_tuning/backbone.py b/modelscope/models/cv/vision_efficient_tuning/backbone.py index 691e4440..e83fb958 100644 --- a/modelscope/models/cv/vision_efficient_tuning/backbone.py +++ b/modelscope/models/cv/vision_efficient_tuning/backbone.py @@ -191,7 +191,7 @@ class BlockPETL(nn.Module): self.prompt = None def forward(self, x): - if self.prompt is not None: + if self.prompt is not None and self.prompt_length and self.prompt_length > 0: x = self.prompt(x) x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) diff --git a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py index 747aecd8..3f616297 100644 --- a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py +++ b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py @@ -19,8 +19,8 @@ from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.builder import MODELS from modelscope.outputs import OutputKeys -from modelscope.tuners.control_sd_lora import ControlLoRATuner -from modelscope.tuners.sd_lora import LoRATuner +from modelscope.swift.control_sd_lora import ControlLoRATuner +from modelscope.swift.sd_lora import LoRATuner from modelscope.utils.checkpoint import save_checkpoint, save_configuration from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index ebe081a9..c99f04ec 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: from .csanmt import CsanmtForTranslation from .canmt import CanmtForTranslation from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .chatglm import ChatGLMForConditionalGeneration, ChatGLMTokenizer, ChatGLMConfig + from .chatglm2 import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer, ChatGLM2Config from .gpt_neo import GPTNeoModel from .gpt2 import GPT2Model from .gpt3 import GPT3ForTextGeneration, DistributedGPT3 @@ -95,6 +97,14 @@ else: ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], 'glm_130b': ['GLM130bForTextGeneration'], 'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'], + 'chatglm': [ + 'ChatGLMForConditionalGeneration', 'ChatGLMTokenizer', + 'ChatGLMConfig' + ], + 'chatglm2': [ + 'ChatGLM2ForConditionalGeneration', 'ChatGLM2Tokenizer', + 'ChatGLM2Config' + ], 'heads': ['TextClassificationHead'], 'hf_transformers': ['TransformersModel'], 'gpt2': ['GPT2Model'], diff --git a/modelscope/models/nlp/chatglm/__init__.py b/modelscope/models/nlp/chatglm/__init__.py new file mode 100644 index 00000000..2a0a073f --- /dev/null +++ b/modelscope/models/nlp/chatglm/__init__.py @@ -0,0 +1,46 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import ChatGLMConfig + from .tokenization import ChatGLMTokenizer + from .text_generation import ChatGLMForConditionalGeneration + from .quantization import ( + quantize, ) + +else: + _import_structure = { + 'configuration': ['ChatGLMConfig'], + 'text_generation': ['ChatGLMForConditionalGeneration'], + 'quantization': ['quantize'], + 'tokenization': [ + 'ChatGLMTokenizer', + ], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/chatglm/configuration.py b/modelscope/models/nlp/chatglm/configuration.py new file mode 100644 index 00000000..18fdca0f --- /dev/null +++ b/modelscope/models/nlp/chatglm/configuration.py @@ -0,0 +1,101 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ChatGLMModel`]. + It is used to instantiate an ChatGLM model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 150528): + Vocabulary size of the ChatGLM-6B model. + Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ChatGLMModel`] or + [`~TFChatGLMModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + inner_hidden_size (`int`, *optional*, defaults to 16384): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from modelscope.models.nlp.chatglm.configuration import ChatGLMConfig + >>> from modelscope.models.nlp.chatglm.text_generation import ChatGLMModel + + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration + >>> configuration = ChatGLMConfig() + + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration + >>> model = ChatGLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` +""" + model_type = 'chatglm' + + def __init__(self, + vocab_size=150528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=False, + bos_token_id=150004, + eos_token_id=150005, + mask_token_id=150000, + gmask_token_id=150001, + pad_token_id=0, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.layernorm_epsilon = layernorm_epsilon + self.inner_hidden_size = inner_hidden_size + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id + self.position_encoding_2d = position_encoding_2d + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) diff --git a/modelscope/models/nlp/chatglm/quantization.py b/modelscope/models/nlp/chatglm/quantization.py new file mode 100644 index 00000000..9994d9c4 --- /dev/null +++ b/modelscope/models/nlp/chatglm/quantization.py @@ -0,0 +1,234 @@ +import base64 +import bz2 +import ctypes +from typing import List + +import torch +from torch.nn import Linear +from torch.nn.parameter import Parameter +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = '$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ' # noqa + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + 'int4WeightCompression', + 'int4WeightExtractionFloat', + 'int4WeightExtractionHalf', + 'int8WeightExtractionFloat', + 'int8WeightExtractionHalf', + ], + ) +except Exception as exception: + kernels = None + logger.warning('Failed to load cpm_kernels:' + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, + scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features, ))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view( + ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + if kernels is None: + raise RuntimeError( + 'kernels is None, please check whether it is correctly initialized.' + ) + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m) + ], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, + source_bit_width: int): + if kernels is None: + raise RuntimeError( + 'kernels is None, please check whether it is correctly initialized.' + ) + if source_bit_width == 8: + func = kernels.int8WeightExtractionHalf + elif source_bit_width == 4: + func = kernels.int4WeightExtractionHalf + else: + assert False, 'Unsupported bit-width' + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty( + n, m * (8 // source_bit_width), dtype=torch.half, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(Linear): + + def __init__(self, + weight_bit_width: int, + weight_tensor=None, + bias_tensor=None, + empty_init=False, + *args, + **kwargs): + super(QuantizedLinear, self).__init__(*args, **kwargs) + self.weight_bit_width = weight_bit_width + + shape = self.weight.shape + del self.weight + + if weight_tensor is None or empty_init: + self.weight = torch.empty( + shape[0], + shape[1] * weight_bit_width // 8, + dtype=torch.int8, + device=kwargs['device']) + self.weight_scale = torch.empty( + shape[0], dtype=kwargs['dtype'], device=kwargs['device']) + else: + self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ( + (2**(weight_bit_width - 1)) - 1)).half() # noqa + self.weight = torch.round( + weight_tensor / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter( + self.weight.to(kwargs['device']), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs['device']), requires_grad=False) + if bias_tensor is not None: + self.bias = Parameter( + bias_tensor.to(kwargs['device']), requires_grad=False) + else: + self.bias = None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, + self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, **kwargs): + """Replace fp16 linear with quantized linear""" + + for layer in model.layers: + layer.attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.query_key_value.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.attention.query_key_value.bias, + in_features=layer.attention.query_key_value.in_features, + out_features=layer.attention.query_key_value.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.query_key_value.weight.device, + empty_init=empty_init) + layer.attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.dense.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.attention.dense.bias, + in_features=layer.attention.dense.in_features, + out_features=layer.attention.dense.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.dense.weight.device, + empty_init=empty_init) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_h_to_4h.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_h_to_4h.bias, + in_features=layer.mlp.dense_h_to_4h.in_features, + out_features=layer.mlp.dense_h_to_4h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_h_to_4h.weight.device, + empty_init=empty_init) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_4h_to_h.weight.to( + torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_4h_to_h.bias, + in_features=layer.mlp.dense_4h_to_h.in_features, + out_features=layer.mlp.dense_4h_to_h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_4h_to_h.weight.device, + empty_init=empty_init) + return model diff --git a/modelscope/models/nlp/chatglm/text_generation.py b/modelscope/models/nlp/chatglm/text_generation.py new file mode 100644 index 00000000..ff32c86d --- /dev/null +++ b/modelscope/models/nlp/chatglm/text_generation.py @@ -0,0 +1,1571 @@ +""" PyTorch ChatGLM model. """ + +import copy +import math +import os +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import (GenerationConfig, + LogitsProcessorList, ModelOutput, + StoppingCriteriaList) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging) + +from modelscope.metainfo import Models +from modelscope.models import MODELS, Model, TorchModel +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from .configuration import ChatGLMConfig +from .tokenization import ChatGLMTokenizer + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM-6B' +_CONFIG_FOR_DOC = 'ChatGLM6BConfig' + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'THUDM/chatglm-6b', + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f'Converting TensorFlow checkpoint from {tf_path}') + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f'Loading TF weight {name} with shape {shape}') + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', 'global_step' + ] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f'Initialize PyTorch weight {name}') + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, + config.num_layers * config.hidden_size * 2)) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * ( + 1.0 + torch.tanh(0.7978845608028654 * x * # noqa + (1.0 + 0.044715 * x * x))) # noqa + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base**(torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or ( + seq_len > self.max_seq_len_cached): # noqa + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange( + seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat( + (-x2, x1), + dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + ( + rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=0) + value_layer = torch.cat((past_value, value_layer), dim=0) + + # seqlen, batch, num_attention_heads, hidden_size_per_attention_head + seq_len, b, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / ( + math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], + -1) + + matmul_result = torch.zeros( + 1, + 1, + 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + hidden_size_per_partition, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +class SelfAttention(torch.nn.Module): + + def __init__(self, + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=None, + bias=True, + params_dtype=torch.float, + position_encoding_2d=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( # noqa + self.hidden_size // # noqa + (self.num_attention_heads * 2) if position_encoding_2d else # noqa + self.hidden_size // self.num_attention_heads, # noqa + base=10000, # noqa + precision=torch.half, # noqa + learnable=False, # noqa + ) # noqa + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = skip_init( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, + tensor, + num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [seq_len, batch, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, + # 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, + value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ + position_ids[:, 1, :].transpose(0, 1).contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, + block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb( + value_layer, seq_len=position_ids.max() + 1) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index( + query_layer, key_layer, cos, sin, position_ids) + + # [seq_len, batch, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs, ) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + + def __init__(self, + hidden_size, + inner_hidden_size=None, + layer_id=None, + bias=True, + activation_func=gelu, + params_dtype=torch.float): + super(GLU, self).__init__() + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = skip_init( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = skip_init( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + + def __init__(self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm( + hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [seq_len, batch, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers)**0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output, ) + outputs + else: + outputs = (output, ) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = 'transformer' + _no_split_modules = ['GLMBlock'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [ + seq.tolist().index(self.config.bos_token_id) for seq in input_ids + ] + attention_mask = torch.ones((batch_size, seq_length, seq_length), + device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, gmask=False): + batch_size, seq_length = input_ids.shape + context_lengths = [ + seq.tolist().index(self.config.bos_token_id) for seq in input_ids + ] + if self.position_encoding_2d: + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [ + torch.cat(( + torch.zeros( # noqa + context_length, + dtype=torch.long, + device=device), # noqa + torch.arange( # noqa + seq_length - context_length, # noqa + dtype=torch.long, # noqa + device=device) + 1)) # noqa + for context_length in context_lengths + ] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), + dim=1) + else: + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + if not gmask: + for i, context_length in enumerate(context_lengths): + position_ids[context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + kwargs.pop('cfg', None) + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **kwargs) + model.model_dir = model_dir + return model + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.', + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__(config) + + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = skip_init( + torch.nn.Embedding, + num_embeddings=self.vocab_size, + embedding_dim=self.hidden_size, + dtype=self.params_dtype) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self. + hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, + -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, self.pre_seq_len, self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads) + # seq_len, b, nh, hidden_size + print('#########################:', past_key_values.device) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward( + CHATGLM_6B_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], + ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + # ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape[:2] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt( + batch_size=input_ids.shape[0], + device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, device=input_ids.device) + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False + + mask_positions = [ + seq.tolist().index(mask_token) for seq in input_ids + ] + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + gmask=use_gmask) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones( + batch_size, 1, input_ids.size(-1), + self.pre_seq_len).to(attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), + dim=3) + + # [seq_len, batch, hidden_size] + hidden_states = inputs_embeds.transpose(0, 1) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + + else: + attention_mask = attention_mask.to(input_ids.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, position_ids, attention_mask, + torch.tensor(i), layer_past, use_cache, output_attentions) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1], ) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_ret[2 if use_cache else 1], ) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, presents, all_hidden_states, all_self_attentions + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@MODELS.register_module(Tasks.chat, module_name=Models.chatglm_6b) +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig): + super().__init__(config) + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config) + + self.lm_head = skip_init( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + # loading tokenizer + self.tokenizer = ChatGLMTokenizer.from_pretrained(config.name_or_path) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs['past_key_values'] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if 'attention_mask' in model_kwargs: + attention_mask = model_kwargs['attention_mask'] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat([ + attention_mask, + attention_mask.new_ones((*attention_mask.shape[:3], 1)) + ], + dim=3) # noqa + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs['attention_mask'] = torch.cat( + [attention_mask, new_attention_mask], dim=2) + + # update position ids + if 'position_ids' in model_kwargs: + position_ids = model_kwargs['position_ids'] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs['position_ids'] = torch.cat( + [position_ids, new_position_id], dim=-1) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + mask_token = gMASK if gMASK in input_ids else MASK + use_gmask = True if gMASK in input_ids else False + seqs = input_ids.tolist() + mask_positions = [seq.index(mask_token) for seq in seqs] + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + else: + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [ + seq.index(self.config.bos_token_id) for seq in seqs + ] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] + for mask_position, context_length in zip( + mask_positions, context_lengths)], + dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor( + [mask_position for mask_position in mask_positions], + dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + 'input_ids': last_token, + 'past_key_values': past, + 'position_ids': position_ids, + 'attention_mask': attention_mask + } + else: + if attention_mask is not None and attention_mask.dtype != torch.bool: + # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, device=input_ids.device) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + gmask=use_gmask) + + return { + 'input_ids': input_ids, + 'past_key_values': past, + 'position_ids': position_ids, + 'attention_mask': attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], + ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace('[[训练时间]]', '2023年') + punkts = [ + [',', ','], + ['!', '!'], + [':', ':'], + [';', ';'], + ['\?', '?'], # noqa + ] + for item in punkts: + response = re.sub(r'([\u4e00-\u9fff])%s' % item[0], + r'\1%s' % item[1], response) + response = re.sub(r'%s([\u4e00-\u9fff])' % item[0], + r'%s\1' % item[1], response) + return response + + @torch.no_grad() + def _chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if not history: + prompt = query + else: + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + i, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if not history: + prompt = query + else: + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n问:{}\n答:{}\n'.format( + i, old_query, response) + prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + **kwargs, + ): + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[ + -1] # noqa + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' + f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' + f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' + ' increasing `max_new_tokens`.') + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( + ) + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList( + ) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info('Already quantized.') + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize( + self.transformer, bits, empty_init=empty_init, **kwargs) + return self + + def chat(self, input: Dict) -> Dict: + text = input['text'] + history = input['history'] + # args + if 'max_length' in input: + max_length = input['max_length'] + else: + max_length = 2048 + + if 'temperature' in input: + temperature = input['temperature'] + else: + temperature = 0.95 + + if 'num_beams' in input: + num_beams = input['num_beams'] + else: + num_beams = 1 + + if 'do_sample' in input: + do_sample = input['do_sample'] + else: + do_sample = True + + if type(history) == torch.Tensor: + history = history.tolist() + response, history = self._chat( + self.tokenizer, + text, + history, + max_length=max_length, + temperature=temperature, + num_beams=num_beams, + do_sample=do_sample) + logger.info('Generation finished.') + return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} diff --git a/modelscope/models/nlp/chatglm/tokenization.py b/modelscope/models/nlp/chatglm/tokenization.py new file mode 100644 index 00000000..77bcde55 --- /dev/null +++ b/modelscope/models/nlp/chatglm/tokenization.py @@ -0,0 +1,463 @@ +"""Tokenization classes for ChatGLM.""" +import os +from typing import Dict, List, Optional, Union + +import numpy as np +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging + +logger = logging.get_logger(__name__) + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'THUDM/chatglm-6b': 2048, +} + + +class TextTokenizer: + + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + +class SPTokenizer: + + def __init__( + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, + ): + assert vocab_file is not None + self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens + self.special_tokens = [ + '[MASK]', '[gMASK]', '[sMASK]', '', '', '', + '', '' + ] + self.max_blank_length = max_blank_length + self.byte_fallback = byte_fallback + self.text_tokenizer = TextTokenizer(vocab_file) + + def _get_text_tokenizer(self): + return self.text_tokenizer + + @staticmethod + def get_blank_token(length: int): + assert length >= 2 + return f'<|blank_{length}|>' + + @staticmethod + def get_tab_token(): + return '<|tab|>' + + @property + def num_text_tokens(self): + return self.text_tokenizer.num_tokens + + @property + def num_tokens(self): + return self.num_image_tokens + self.num_text_tokens + + @staticmethod + def _encode_whitespaces(text: str, max_len: int = 80): + text = text.replace('\t', SPTokenizer.get_tab_token()) + for i in range(max_len, 1, -1): + text = text.replace(' ' * i, SPTokenizer.get_blank_token(i)) + return text + + def _preprocess(self, text: str, linebreak=True, whitespaces=True): + if linebreak: + text = text.replace('\n', '') + if whitespaces: + text = self._encode_whitespaces( + text, max_len=self.max_blank_length) + return text + + def encode(self, + text: str, + linebreak=True, + whitespaces=True, + add_dummy_prefix=True) -> List[int]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = '' + text + tmp = self._get_text_tokenizer().encode(text) + tokens = [x + self.num_image_tokens for x in tmp] + return tokens if add_dummy_prefix else tokens[2:] + + def decode(self, text_ids: List[int]) -> str: + ids = [int(_id) - self.num_image_tokens for _id in text_ids] + ids = [_id for _id in ids if _id >= 0] + text = self._get_text_tokenizer().decode(ids) + text = text.replace('', '\n') + text = text.replace(SPTokenizer.get_tab_token(), '\t') + for i in range(2, self.max_blank_length + 1): + text = text.replace(self.get_blank_token(i), ' ' * i) + return text + + def tokenize(self, + text: str, + linebreak=True, + whitespaces=True, + add_dummy_prefix=True) -> List[str]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = '' + text + tokens = self._get_text_tokenizer().tokenize(text) + return tokens if add_dummy_prefix else tokens[2:] + + def __getitem__(self, x: Union[int, str]): + if isinstance(x, int): + if x < self.num_image_tokens: + return ''.format(x) + else: + return self.text_tokenizer.convert_id_to_token( + x - self.num_image_tokens) + elif isinstance(x, str): + if x.startswith('') and x[7:-1].isdigit(): + return int(x[7:-1]) + else: + return self.text_tokenizer.convert_token_to_id( + x) + self.num_image_tokens + else: + raise ValueError('The key should be str or int.') + + +class ChatGLMTokenizer(PreTrainedTokenizer): + """ + Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file: Path to the vocabulary file. + do_lower_case: Use lower case letters. + remove_space: Remove spaces. + bos_token: The bos token + eos_token: The Eos Token + end_token: The end token + mask_token: The mask token + gmask_token: The gmask token + padding_side: The padding side + num_image_tokens: The `num_image_tokens` in `SPTokenizer` + """ + + vocab_files_names = {'vocab_file': 'ice_text.model'} + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ['input_ids', 'attention_mask', 'position_ids'] + + def __init__(self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token='', + eos_token='', + end_token='', + mask_token='[MASK]', + gmask_token='[gMASK]', + padding_side='left', + num_image_tokens=20000, + **kwargs) -> None: + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + num_image_tokens=num_image_tokens, + **kwargs) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.vocab_file = vocab_file + + self.bos_token = bos_token + self.eos_token = eos_token + self.end_token = end_token + self.mask_token = mask_token + self.gmask_token = gmask_token + + self.sp_tokenizer = SPTokenizer( + vocab_file, num_image_tokens=num_image_tokens) + """ Initialisation """ + + @property + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been + set. + """ + if self.end_token is None: + return None + return self.convert_tokens_to_ids(self.end_token) + + @property + def vocab_size(self): + """ Returns vocab size """ + return self.sp_tokenizer.num_tokens + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = { + self._convert_id_to_token(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = ' '.join(inputs.strip().split()) + else: + outputs = inputs + + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text, **kwargs): + """ Returns a tokenized string. """ + text = self.preprocess_text(text) + + seq = self.sp_tokenizer.tokenize(text) + + return seq + + def _decode(self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return '' + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return self.sp_tokenizer.decode(token_ids) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_tokenizer[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_tokenizer[index] + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, + self.vocab_files_names['vocab_file']) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, 'wb') as writer: + writer.write(proto_str) + + return (vocab_file, ) + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + mask_ids = self.sp_tokenizer[self.mask_token] + gmask_ids = self.sp_tokenizer[self.gmask_token] + eos_id = self.sp_tokenizer[self.eos_token] + if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: + token_ids_0 += [gmask_ids] + + if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids: + token_ids_0 += [self.sp_tokenizer[self.end_token]] + + token_ids_0 += [self.sp_tokenizer[self.bos_token]] + + if token_ids_1 is not None: + if not token_ids_1 or token_ids_1[-1] != eos_id: + token_ids_1 += [eos_id] + token_ids_0 += token_ids_1 + + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == 'left' + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and ( + max_length % pad_to_multiple_of != 0): + max_length = ( + (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( + required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if 'attention_mask' not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs['attention_mask'] = attention_mask + + if 'position_ids' not in encoded_inputs: + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate([ + np.zeros(context_length, dtype=np.int64), + np.arange( + 1, seq_length - context_length + 1, dtype=np.int64) + ]) + encoded_inputs['position_ids'] = np.stack( + [position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if 'attention_mask' in encoded_inputs: + encoded_inputs['attention_mask'] = np.pad( + encoded_inputs['attention_mask'], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', + constant_values=True) + if 'token_type_ids' in encoded_inputs: + encoded_inputs['token_type_ids'] = [ + self.pad_token_type_id + ] * difference + encoded_inputs['token_type_ids'] + if 'special_tokens_mask' in encoded_inputs: + encoded_inputs['special_tokens_mask'] = [ + 1 + ] * difference + encoded_inputs['special_tokens_mask'] + if 'position_ids' in encoded_inputs: + encoded_inputs['position_ids'] = np.pad( + encoded_inputs['position_ids'], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[ + 0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/modelscope/models/nlp/chatglm2/__init__.py b/modelscope/models/nlp/chatglm2/__init__.py new file mode 100644 index 00000000..a2b5bfea --- /dev/null +++ b/modelscope/models/nlp/chatglm2/__init__.py @@ -0,0 +1,46 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import ChatGLM2Config + from .tokenization import ChatGLM2Tokenizer + from .text_generation import ChatGLM2ForConditionalGeneration + from .quantization import ( + quantize, ) + +else: + _import_structure = { + 'configuration': ['ChatGLM2Config'], + 'text_generation': ['ChatGLM2ForConditionalGeneration'], + 'quantization': ['quantize'], + 'tokenization': [ + 'ChatGLM2Tokenizer', + ], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/chatglm2/configuration.py b/modelscope/models/nlp/chatglm2/configuration.py new file mode 100644 index 00000000..b10db870 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/configuration.py @@ -0,0 +1,58 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLM2Config(PretrainedConfig): + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + **kwargs): + self.num_layers = num_layers + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + super().__init__(**kwargs) diff --git a/modelscope/models/nlp/chatglm2/quantization.py b/modelscope/models/nlp/chatglm2/quantization.py new file mode 100644 index 00000000..612c9e4b --- /dev/null +++ b/modelscope/models/nlp/chatglm2/quantization.py @@ -0,0 +1,223 @@ +import base64 +import bz2 +import ctypes +from functools import partial +from typing import List + +import torch +from torch.nn import Linear +from torch.nn.parameter import Parameter +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = '$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ' # noqa + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + 'int4WeightCompression', + 'int4WeightExtractionFloat', + 'int4WeightExtractionHalf', + 'int8WeightExtractionFloat', + 'int8WeightExtractionHalf', + ], + ) +except Exception as exception: + kernels = None + logger.warning('Failed to load cpm_kernels:' + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, + scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features, ))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view( + ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m) + ], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, + source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half + else kernels.int4WeightExtractionBFloat16) + else: + assert False, 'Unsupported bit-width' + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty( + n, + m * (8 // source_bit_width), + dtype=scale_list.dtype, + device='cuda') + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + + def __init__(self, + weight_bit_width: int, + weight, + bias=None, + device='cpu', + dtype=None, + empty_init=False, + *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty( + shape[0], + shape[1] * weight_bit_width // 8, + dtype=torch.int8, + device=device) + self.weight_scale = torch.empty( + shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ( + (2**(weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to( + torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter( + bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, + self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to( + torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device + if device is None else device, + empty_init=empty_init) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to( + torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device + if device is None else device, + empty_init=empty_init) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to( + torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device + if device is None else device, + empty_init=empty_init) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to( + torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device + if device is None else device, + empty_init=empty_init) + + return model diff --git a/modelscope/models/nlp/chatglm2/text_generation.py b/modelscope/models/nlp/chatglm2/text_generation.py new file mode 100644 index 00000000..be744f14 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/text_generation.py @@ -0,0 +1,1299 @@ +""" PyTorch ChatGLM model. """ + +import copy +import math +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import (GenerationConfig, + LogitsProcessorList, ModelOutput, + StoppingCriteriaList) +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from modelscope.metainfo import Models +from modelscope.models import MODELS, Model, TorchModel +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from .configuration import ChatGLM2Config + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM2-6B' +_CONFIG_FOR_DOC = 'ChatGLM6BConfig' + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'THUDM/chatglm2-6b', + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000**( + torch.arange(0, dim, 2, device=device, dtype=dtype) / dim)) + self.register_buffer('inv_freq', inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl(self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / ( + base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) + / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack( + [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16( + ) if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, + rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, + normalized_shape, + eps=1e-5, + device=None, + dtype=None, + **kwargs): + super().__init__() + self.weight = torch.nn.Parameter( + torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean( + -1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + + def __init__(self, config: ChatGLM2Config, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [ + k.permute(1, 2, 0, 3) + for k in [query_layer, key_layer, value_layer] + ] + if attention_mask is None and query_layer.shape[ + 2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, ) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[ + 2] == attention_scores.shape[3]: + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill( + attention_mask, float('-inf')) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, + value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, ) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLM2Config, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head + * config.multi_query_group_num) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config)) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config)) + + def _allocate_memory(self, + inference_max_sequence_len, + batch_size, + device=None, + dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward(self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, # noqa + 3 * self.hidden_size_per_attention_head) # noqa + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if use_cache: + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, -1) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, -1) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + 'dtype': args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLM2Config, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLM2Config, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention( + config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout( + attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout( + mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLM2Config, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer = self._get_layer(index) + + hidden_states, kv_cache = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache) + if use_cache: + presents = presents + (kv_cache, ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLM2Config + base_model_prefix = 'transformer' + _no_split_modules = ['GLMBlock'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones( + batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones( # noqa + batch_size, + seq_length, + past_length, # noqa + device=input_ids.device), + full_attention_mask), # noqa + dim=-1) # noqa + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze( + 1) # noqa + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange( + seq_length, dtype=torch.long, + device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + kwargs.pop('cfg', None) + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **kwargs) + model.model_dir = model_dir + return model + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLM2Config, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLM2Config, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs['device'] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads + if config.kv_channels is None else config.kv_channels) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs) + self.gradient_checkpointing = False + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], + ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if full_attention_mask is None: + if (attention_mask is not None + and not attention_mask.all()) or (past_key_values + and seq_length != 1): + full_attention_mask = self.get_masks( + input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states) + + if not return_dict: + return tuple(v for v in [ + hidden_states, presents, all_hidden_states, all_self_attentions + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +@MODELS.register_module(Tasks.chat, module_name=Models.chatglm2_6b) +class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLM2Config, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel( + config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs['past_key_values'] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if 'attention_mask' in model_kwargs: + attention_mask = model_kwargs['attention_mask'] + model_kwargs['attention_mask'] = torch.cat( + [ # noqa + attention_mask, # noqa + attention_mask.new_ones( + (attention_mask.shape[0], 1)) # noqa + ], + dim=-1) # noqa + + # update position ids + if 'position_ids' in model_kwargs: + position_ids = model_kwargs['position_ids'] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs['position_ids'] = torch.cat( + [position_ids, new_position_id], dim=-1) + + model_kwargs['is_first_forward'] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + 'input_ids': input_ids, + 'past_key_values': past_key_values, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'return_last_logit': True + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], + ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace('[[训练时间]]', '2023年') + return response + + def build_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None): + prompt = '' + for i, (old_query, response) in enumerate(history): + prompt += '[Round {}]\n\n问:{}\n\n答:{}\n\n'.format( + i + 1, old_query, response) + prompt += '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None): + if history: + prompt = '\n\n[Round {}]\n\n问:{}\n\n答:'.format( + len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], + return_tensors='pt', + add_special_tokens=False) + else: + prompt = '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors='pt') + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'num_beams': num_beams, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} + + @torch.no_grad() + def stream_chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 2048, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + 'max_length': max_length, + 'do_sample': do_sample, + 'top_p': top_p, + 'temperature': temperature, + 'logits_processor': logits_processor, + **kwargs + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs( + tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat( + (attention_mask.new_ones(1, past_length), attention_mask), + dim=1) + inputs['attention_mask'] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=' + f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ' + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to' + f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider' + ' increasing `max_new_tokens`.') + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( + ) + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList( + ) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info('Already quantized.') + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs) + return self diff --git a/modelscope/models/nlp/chatglm2/tokenization.py b/modelscope/models/nlp/chatglm2/tokenization.py new file mode 100644 index 00000000..5036d881 --- /dev/null +++ b/modelscope/models/nlp/chatglm2/tokenization.py @@ -0,0 +1,251 @@ +"""Tokenization classes for ChatGLM.""" +import os +from typing import Dict, List, Optional, Union + +from sentencepiece import SentencePieceProcessor +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging + +logger = logging.get_logger(__name__) + + +class SPTokenizer: + + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.eos_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + special_tokens = ['[MASK]', '[gMASK]', '[sMASK]', 'sop', 'eop'] + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + + def tokenize(self, s: str): + return self.sp_model.EncodeAsPieces(s) + + def encode(self, + s: str, + bos: bool = False, + eos: bool = False) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return '' + return self.sp_model.IdToPiece(index) + + +class ChatGLM2Tokenizer(PreTrainedTokenizer): + vocab_files_names = {'vocab_file': 'tokenizer.model'} + + model_input_names = ['input_ids', 'attention_mask', 'position_ids'] + + def __init__(self, vocab_file, padding_side='left', **kwargs): + super().__init__(padding_side=padding_side, **kwargs) + self.name = 'GLMTokenizer' + + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + '': self.tokenizer.bos_id, + '': self.tokenizer.eos_id, + '': self.tokenizer.pad_id + } + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f'{token} is not a special token for {self.name}' + return self.tokenizer.special_tokens[token] + + @property + def pad_token(self) -> str: + return '' + + @property + def pad_token_id(self): + return self.get_command('') + + @property + def eos_token_id(self): + return self.get_command('') + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = { + self._convert_id_to_token(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, + self.vocab_files_names['vocab_file']) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, 'wb') as writer: + writer.write(proto_str) + + return (vocab_file, ) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command('[gMASK]'), self.get_command('sop')] + return prefix_tokens + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [ + self.get_command('') + ] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == 'left' + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and ( + max_length % pad_to_multiple_of != 0): + max_length = ( + (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( + required_input) != max_length + + # Initialize attention mask if not present. + if 'attention_mask' not in encoded_inputs: + encoded_inputs['attention_mask'] = [1] * seq_length + + if 'position_ids' not in encoded_inputs: + encoded_inputs['position_ids'] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if 'attention_mask' in encoded_inputs: + encoded_inputs['attention_mask'] = [ + 0 + ] * difference + encoded_inputs['attention_mask'] + if 'position_ids' in encoded_inputs: + encoded_inputs['position_ids'] = [ + 0 + ] * difference + encoded_inputs['position_ids'] + encoded_inputs[self.model_input_names[ + 0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index cfc3645d..a0e8a0ee 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright (c) 2022 Zhipu.AI import os from typing import Any, Dict, Optional, Union @@ -17,7 +18,10 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.hub import Config, read_config from modelscope.utils.streaming_output import PipelineStreamingOutputMixin -__all__ = ['TextGenerationPipeline', 'TextGenerationT5Pipeline'] +__all__ = [ + 'TextGenerationPipeline', 'TextGenerationT5Pipeline', + 'ChatGLM6bTextGenerationPipeline', 'ChatGLM6bV2TextGenerationPipeline' +] @PIPELINES.register_module( @@ -177,3 +181,71 @@ class TextGenerationT5Pipeline(TextGenerationPipeline): with torch.no_grad(): return self.model.generate(**inputs, **forward_params) + + +@PIPELINES.register_module( + group_key=Tasks.chat, module_name='chatglm6b-text-generation') +class ChatGLM6bTextGenerationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + quantization_bit=None, + use_bf16=False, + **kwargs): + from modelscope.models.nlp.chatglm.text_generation import ChatGLMForConditionalGeneration + model = ChatGLMForConditionalGeneration(model) if isinstance( + model, str) else model + if quantization_bit is not None: + model = model.quantize(quantization_bit) + if use_bf16: + model = model.bfloat16() + self.model = model + self.model.eval() + + super().__init__(model=model, **kwargs) + + def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: + return inputs + + # define the forward pass + def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]: + return self.model.chat(inputs) + + # format the outputs from pipeline + def postprocess(self, input, **kwargs) -> Dict[str, Any]: + return input + + +@PIPELINES.register_module( + group_key=Tasks.chat, module_name='chatglm2_6b-text-generation') +class ChatGLM6bV2TextGenerationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + quantization_bit=None, + use_bf16=False, + **kwargs): + from modelscope.models.nlp import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer + model = ChatGLM2ForConditionalGeneration(model) if isinstance( + model, str) else model + if quantization_bit is not None: + model = model.quantize(quantization_bit) + if use_bf16: + model = model.bfloat16() + self.model = model + self.model.eval() + self.tokenizer = ChatGLM2Tokenizer.from_pretrained( + self.model.model_dir) + + super().__init__(model=model, **kwargs) + + def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: + return inputs + + # define the forward pass + def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]: + return self.model.chat(self.tokenizer, inputs['text']) + + # format the outputs from pipeline + def postprocess(self, input, **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/swift/__init__.py b/modelscope/swift/__init__.py new file mode 100644 index 00000000..bd8ea75e --- /dev/null +++ b/modelscope/swift/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .optimizers.child_tuning_adamw_optimizer import calculate_fisher, ChildTuningAdamW + from .adapter import Adapter, AdapterConfig, AdapterModule + from .lora import LoRA, LoRAConfig, Linear, MergedLinear, Embedding, Conv2d + from .prompt import Prompt, PromptConfig, PromptModule + from .control_sd_lora import ControlLoRACrossAttnProcessor, ControlLoRACrossAttnProcessorV2, ControlLoRATuner + from .base import SwiftConfig, Swift +else: + _import_structure = { + 'optimizers.child_tuning_adamw_optimizer': + ['calculate_fisher', 'ChildTuningAdamW'], + 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'], + 'lora': [ + 'LoRA', 'LoRAConfig', 'Linear', 'MergedLinear', 'Embedding', + 'Conv2d' + ], + 'prompt': ['Prompt', 'PromptConfig', 'PromptModule'], + 'control_sd_lora': [ + 'ControlLoRACrossAttnProcessor', 'ControlLoRACrossAttnProcessorV2', + 'ControlLoRATuner' + ], + 'base': ['SwiftConfig', 'Swift'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/swift/adapter.py b/modelscope/swift/adapter.py new file mode 100644 index 00000000..d7366119 --- /dev/null +++ b/modelscope/swift/adapter.py @@ -0,0 +1,195 @@ +import inspect +import os +import re +import types +from dataclasses import dataclass, field +from typing import Union + +import torch +from torch import nn + +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + + +@dataclass +class AdapterConfig(SwiftConfig): + """ + The configuration class for the adapter module. + + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Args: + dim: The dimension of the hidden states + module_name: The feedforward module to be replaced, in regex format + hidden_pos: The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs) + method_name: The method to be replaced, default to replace the forward method + adapter_length: The length of the adapter length (intermediate length) + act_layer: The activation layer of the adapter + only_adapter_trainable: Whether to train only adapters + pretrained_weights: The pretrained adapter weights. + Can be a local dir, local file, or a model id from modelscope + """ + + dim: int = field(metadata={'help': 'The dimension of the hidden states'}) + + module_name: str = field( + metadata={ + 'help': 'The feedforward module to be replaced, in regex format' + }) + + hidden_pos: Union[str, int] = field( + metadata={ + 'help': + 'The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)' + }) + + method_name: str = field( + default='forward', + metadata={ + 'help': + 'The method to be replaced, default to replace the forward method' + }) + + adapter_length: int = field( + default=128, + metadata={ + 'help': 'The length of the adapter length (intermediate length)' + }) + + act_layer: nn.Module = field( + default=nn.GELU, + metadata={'help': 'The activation layer of the adapter'}) + + only_adapter_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only adapters'}) + + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained adapter weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class Adapter: + + @staticmethod + def prepare_model(model: nn.Module, config: AdapterConfig): + module_keys = [key for key, _ in model.named_modules()] + + for module_key in module_keys: + if re.fullmatch(config.module_name, module_key): # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + args = self.forward_origin(*args, **kwargs) + if isinstance(args, (tuple, list, dict)): + if isinstance(config.hidden_pos, int): + return args[0:config.hidden_pos] + args[ + config.hidden_pos] + getattr(self, 'adapter')(args[config.hidden_pos]) \ + + args[config.hidden_pos + 1:] # noqa + else: + kwargs[config.hidden_pos] = args[ + config.hidden_pos] + getattr(self, 'adapter')( + args[config.hidden_pos]) + elif isinstance(args, torch.Tensor): + args = getattr(self, 'adapter')(args) + return args + + def _feed_forward_chunk(self, attention_output): + return _forward(self, attention_output) + + module.forward_origin = getattr(module, config.method_name) + num_args_in_forward_chunk_fn = len( + inspect.signature(module.forward_origin).parameters) + if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1: + setattr(module, config.method_name, + types.MethodType(_feed_forward_chunk, module)) + else: + setattr(module, config.method_name, + types.MethodType(_forward, module)) + adapter_module = AdapterModule(config.dim, + config.adapter_length, + config.act_layer) + setattr(module, 'adapter', adapter_module) + + if config.only_adapter_trainable: + for n, p in model.named_parameters(): + if 'adapter' not in n: + p.requires_grad = False + + def state_dict_hook(module, destination, prefix, local_metadata): + return { + key: value + for key, value in destination.items() if 'adapter' in key + } + + model.state_dict_hook_handle = model._register_state_dict_hook( + state_dict_hook) + + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) + + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) + + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) + return model + + +class AdapterModule(nn.Module): + """The implementation of adapter tuning method. + + Adapters project input tokens by an MLP layer. + 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) + See http://arxiv.org/abs/1902.00751 + + Attributes: + dim: An integer indicating the embedding dimension. + adapter_length: An integer indicating the length of adapter tuning. + """ + + def __init__( + self, + dim, + adapter_length=None, + act_layer=nn.GELU, + ): + super(AdapterModule, self).__init__() + self.dim = dim + self.adapter_length = adapter_length + # self.adapter_type = adapter_type + self.ln1 = nn.Linear(dim, adapter_length) + self.activate = act_layer() + self.ln2 = nn.Linear(adapter_length, dim) + self.init_weights() + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + self.apply(_init_weights) + + def forward(self, x, identity=None): + out = self.ln2(self.activate(self.ln1(x))) + if identity is None: + identity = x + out = identity + out + return out diff --git a/modelscope/swift/base.py b/modelscope/swift/base.py new file mode 100644 index 00000000..441521ca --- /dev/null +++ b/modelscope/swift/base.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + + +@dataclass +class SwiftConfig: + pass + + +class Swift: + + @staticmethod + def prepare_model(model, config: SwiftConfig): + """Prepare the module and returns the new module. + + Args: + model: The model to tune. + config: The config of the tuner. + + Returns: + The tuned model. + """ + from .lora import LoRA, LoRAConfig + from .adapter import Adapter, AdapterConfig + from .prompt import Prompt, PromptConfig + if isinstance(config, LoRAConfig): + return LoRA.prepare_model(model, config) + elif isinstance(config, AdapterConfig): + return Adapter.prepare_model(model, config) + elif isinstance(config, PromptConfig): + return Prompt.prepare_model(model, config) + return None diff --git a/modelscope/tuners/control_sd_lora.py b/modelscope/swift/control_sd_lora.py similarity index 100% rename from modelscope/tuners/control_sd_lora.py rename to modelscope/swift/control_sd_lora.py diff --git a/modelscope/tuners/lora.py b/modelscope/swift/lora.py similarity index 76% rename from modelscope/tuners/lora.py rename to modelscope/swift/lora.py index ba1e92e1..3c0be6ba 100644 --- a/modelscope/tuners/lora.py +++ b/modelscope/swift/lora.py @@ -4,93 +4,148 @@ import logging import math import os.path +import re import types +from dataclasses import dataclass, field from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + logger = logging.getLogger(__name__) -class LoRATuner: +@dataclass +class LoRAConfig(SwiftConfig): + """ + The configuration class for the loRA module. + + Args: + rank: The rank of the LoRA module + replace_modules: The modules to be replaced by LoRA, can be the end of the module name or a regex string + lora_alpha: The factor to add the lora weights + lora_dropout: The dropout rate of the lora module + merge_weights: Whether to merge weights when validating + use_merged_linear: Whether to replace with merged linear layer + enable_lora: The modules need to be turned on when using the merged linear layer + fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out) + bias: Bias type. Values ca be "none", "all" or "lora_only" + only_lora_trainable: Whether to train only lora + pretrained_weights: The pretrained lora weights. + Can be a local dir, local file, or a model id from modelscope + """ + + rank: int = field( + default=6, metadata={'help': 'The rank of the LoRA module'}) + replace_modules: List = field( + default=None, + metadata={ + 'help': + 'The modules to be replaced by LoRA, can be the end of the module name or a regex string' + }) + lora_alpha: float = field( + default=1., metadata={'help': 'The factor to add the lora weights'}) + lora_dropout: float = field( + default=0., metadata={'help': 'The dropout rate of the lora module'}) + merge_weights: bool = field( + default=True, + metadata={'help': 'Whether to merge weights when validating'}) + use_merged_linear: bool = field( + default=False, + metadata={'help': 'Whether to replace with merged linear layer'}) + enable_lora: List = field( + default=None, + metadata={ + 'help': + 'The modules need to be turned on when using the merged linear layer' + }) + fan_in_fan_out: bool = field( + default=False, + metadata={ + 'help': + 'Set this to True if the layer to replace stores weight like (fan_in, fan_out)' + }) + bias: str = field( + default='none', + metadata={ + 'help': 'Bias type. Values ca be "none", "all" or "lora_only"' + }) + only_lora_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only lora'}) + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained lora weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class LoRA: @staticmethod - def tune(model: nn.Module, - rank=6, - replace_modules=None, - lora_alpha=1., - lora_dropout=0., - merge_weights=True, - fan_in_fan_out=False, - bias='none', - pretrained_tuner=None): - """Tune a model with lora. + def prepare_model(model: nn.Module, config: LoRAConfig): + """Tune a model with LoRA. Args: - model: The torch.nn.Module containing the target module to be patched. - rank: The lora rank. - replace_modules: The module names to be replaced, the replacing strategy is `end with`. - lora_alpha: The alpha value for lora module. - lora_dropout: The dropout value for lora module. - merge_weights: If merge_weights set to True, when the module turns to `eval`, the lora weights - will be added into the origin weight to reduce calculation. - fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out). - bias: The grad strategy for bias, can be `none`, 'all' or 'lora_only'. - pretrained_tuner: The pretrained file of lora. + config: The LoRAConfig instance. Returns: The lora modules """ - modules = LoRATuner._dynamic_patch_lora( + LoRA._dynamic_patch_lora( model, - replace_modules=replace_modules, - r=rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights, - fan_in_fan_out=fan_in_fan_out) + replace_modules=config.replace_modules, + r=config.rank, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + merge_weights=config.merge_weights, + use_merged_linear=config.use_merged_linear, + enable_lora=config.enable_lora, + fan_in_fan_out=config.fan_in_fan_out) - mark_only_lora_as_trainable(model, bias) + if config.only_lora_trainable: + mark_only_lora_as_trainable(model, config.bias) def state_dict_hook(module, destination, prefix, local_metadata): - return lora_state_dict(destination, bias) + return lora_state_dict(destination, config.bias) model.state_dict_hook_handle = model._register_state_dict_hook( state_dict_hook) - def warning_hook(module, incompatible_keys): - logger.info( - f'The {module.__class__.__name__} module has unmatched keys: {incompatible_keys},' - f'this is converted to a notice with respect to LoRA') - for ik in incompatible_keys: - ik.clear() + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) - if hasattr(model, 'register_load_state_dict_post_hook'): - model.load_state_dict_hook_handle = model.register_load_state_dict_post_hook( - warning_hook) - else: + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) - def load_state_dict(self, state_dict, strict=True): - return self.load_state_dict_origin(state_dict, False) + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) - model.load_state_dict_origin = model.load_state_dict - model.load_state_dict = types.MethodType(load_state_dict, model) - - if pretrained_tuner is not None and os.path.isfile(pretrained_tuner): - logger.info(f'Loading LoRA weights from file: {pretrained_tuner}') - model.load_state_dict(torch.load(pretrained_tuner)) - - return modules + return model @staticmethod - def _dynamic_patch_lora(model, replace_modules, **kwargs): + def _dynamic_patch_lora(model, replace_modules, use_merged_linear, + **kwargs): """Dynamic patch lora to model Args: model: The torch.nn.Module containing the target module to be patched. replace_modules: The module names to be replaced, the replacing strategy is `end with`. + use_merged_linear: Whether to replace with merged linear layer **kwargs: The arguments passed from `tune` which are needed by lora. Returns: @@ -103,8 +158,13 @@ class LoRATuner: replace_modules = [replace_modules] for module_key in module_keys: - if any([module_key.endswith(name) - for name in replace_modules]): # noqa + if isinstance(replace_modules, str): + target_module_found = re.fullmatch(replace_modules, module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in replace_modules) + if target_module_found: # noqa parts = module_key.split('.') module = model.get_submodule('.'.join(parts[:-1])) sub_module = model.get_submodule(module_key) @@ -112,11 +172,19 @@ class LoRATuner: lora_module = None if isinstance(sub_module, torch.nn.Linear): - lora_module = Linear( - sub_module.in_features, - sub_module.out_features, - bias=sub_module.bias is not None, - **kwargs) + if use_merged_linear: + lora_module = MergedLinear( + sub_module.in_features, + sub_module.out_features, + bias=sub_module.bias is not None, + **kwargs) + else: + kwargs.pop('enable_lora', None) + lora_module = Linear( + sub_module.in_features, + sub_module.out_features, + bias=sub_module.bias is not None, + **kwargs) elif isinstance(sub_module, torch.nn.Conv2d): kwargs.pop('fan_in_fan_out', None) lora_module = Conv2d( @@ -140,9 +208,13 @@ class LoRATuner: return modules @staticmethod - def unpatch_lora(model, replace_modules): + def unpatch_lora(model, config: LoRAConfig): """Unpatch lora modules and merge the weights to original modules. + LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network. + 'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021) + See https://arxiv.org/abs/2106.09685 + Args: model: The model called with `tune` function. replace_modules: The module names to be replaced, the replacing strategy is `end with`. @@ -152,13 +224,17 @@ class LoRATuner: """ modules = [] module_keys = [key for key, _ in model.named_modules()] - assert isinstance(replace_modules, (str, list)) - if isinstance(replace_modules, str): - replace_modules = [replace_modules] + assert isinstance(config.replace_modules, (str, list)) + replace_modules = config.replace_modules for module_key in module_keys: - if any([module_key.endswith(name) - for name in replace_modules]): # noqa + if isinstance(replace_modules, str): + target_module_found = re.fullmatch(replace_modules, module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in replace_modules) + if target_module_found: # noqa parts = module_key.split('.') module = model.get_submodule('.'.join(parts[:-1])) sub_module = model.get_submodule(module_key) diff --git a/modelscope/swift/optimizers/__init__.py b/modelscope/swift/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py b/modelscope/swift/optimizers/child_tuning_adamw_optimizer.py similarity index 97% rename from modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py rename to modelscope/swift/optimizers/child_tuning_adamw_optimizer.py index 74215801..02b459fa 100644 --- a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py +++ b/modelscope/swift/optimizers/child_tuning_adamw_optimizer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import types from typing import Callable, Iterable, Tuple import numpy as np @@ -22,7 +21,6 @@ from torch.distributions.bernoulli import Bernoulli from torch.optim import Optimizer from modelscope.utils.logger import get_logger -from .builder import OPTIMIZERS, default_group logger = get_logger() @@ -72,8 +70,6 @@ def calculate_fisher(model: torch.nn.Module, return gradient_mask -@OPTIMIZERS.register_module( - group_key=default_group, module_name='ChildTuningAdamW') class ChildTuningAdamW(Optimizer): def __init__(self, diff --git a/modelscope/swift/prompt.py b/modelscope/swift/prompt.py new file mode 100644 index 00000000..715f6942 --- /dev/null +++ b/modelscope/swift/prompt.py @@ -0,0 +1,214 @@ +import os +import re +import types +from dataclasses import dataclass, field +from typing import Union + +import torch +from torch import nn + +from modelscope import snapshot_download +from modelscope.utils.constant import ModelFile +from .base import SwiftConfig + + +@dataclass +class PromptConfig(SwiftConfig): + """ + The configuration class for the prompt module. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Here we apply the VPT to other fields. + + Args: + dim: The dimension of the hidden states + module_layer_name: The layer module to be replaced, in regex format + embedding_pos: The position of the embedding tensor + attention_mask_pos: The position of the attention mask + attention_mask_value: The value to pad to the attention mask + prompt_length: The length of the prompt tokens + only_prompt_trainable: Whether to train only prompt + attach_front: When set to True, prompt is attached in front of the embedding + pretrained_weights: The pretrained prompt weights. Can be a local dir, local file, + or a model id from modelscope + """ + + dim: int = field(metadata={'help': 'The dimension of the hidden states'}) + + module_layer_name: str = field( + metadata={'help': 'The layer module to be replaced, in regex format'}) + + embedding_pos: Union[str, int] = field( + metadata={'help': 'The position of the embedding tensor'}) + + attention_mask_pos: Union[str, int] = field( + default=None, metadata={'help': 'The position of the attention mask'}) + + attention_mask_value: Union[float, int, bool] = field( + default=0., + metadata={'help': 'The value to pad to the attention mask'}) + + prompt_length: int = field( + default=16, metadata={'help': 'The length of the prompt tokens'}) + + only_prompt_trainable: bool = field( + default=True, metadata={'help': 'Whether to train only prompt'}) + + attach_front: bool = field( + default=True, + metadata={ + 'help': + 'When set to True, prompt is attached in front of the embedding' + }) + + pretrained_weights: str = field( + default=None, + metadata={ + 'help': + 'The pretrained prompt weights. Can be a local dir, local file, or a model id from modelscope' + }) + + +class Prompt: + + @staticmethod + def prepare_model(model: nn.Module, config: PromptConfig): + module_keys = [key for key, _ in model.named_modules()] + for module_key in module_keys: + if re.fullmatch(config.module_layer_name, module_key): # noqa + module = model.get_submodule(module_key) + + def _forward(self, *args, **kwargs): + if isinstance(config.embedding_pos, int): + input_embedding = args[config.embedding_pos] + else: + input_embedding = kwargs[config.embedding_pos] + + input_embedding = getattr( + self, 'prompt').forward(input_embedding) + if isinstance(config.embedding_pos, int): + args = type(args)( + args[0:config.embedding_pos] + (input_embedding, ) + + args[config.embedding_pos + 1:]) + else: + kwargs[config.embedding_pos] = input_embedding + + if config.attention_mask_pos: + attention_mask = None + if isinstance(config.attention_mask_pos, int): + attention_mask = args[config.attention_mask_pos] + elif isinstance(config.attention_mask_pos, str): + attention_mask = kwargs[config.attention_mask_pos] + + if attention_mask is not None: + attention_mask = getattr( + self, + 'prompt').patch_attention_mask(attention_mask) + if isinstance(config.attention_mask_pos, int): + args = type(args)( + args[0:config.attention_mask_pos] + + (attention_mask, ) + + args[config.attention_mask_pos + 1:]) + else: + kwargs[config.attention_mask_pos] = attention_mask + + return self.forward_origin(*args, **kwargs) + + module.forward_origin = module.forward + module.forward = types.MethodType(_forward, module) + prompt_module = PromptModule(config.dim, + int(module_key.rsplit('.')[-1]), + config.prompt_length, + config.attention_mask_value, + config.attach_front) + setattr(module, 'prompt', prompt_module) + + if config.only_prompt_trainable: + for n, p in model.named_parameters(): + if 'prompt' not in n: + p.requires_grad = False + + def state_dict_hook(module, destination, prefix, local_metadata): + return { + key: value + for key, value in destination.items() if 'prompt' in key + } + + model.state_dict_hook_handle = model._register_state_dict_hook( + state_dict_hook) + + def load_state_dict(self, state_dict, strict=True): + return self.load_state_dict_origin(state_dict, False) + + model.load_state_dict_origin = model.load_state_dict + model.load_state_dict = types.MethodType(load_state_dict, model) + + if config.pretrained_weights is not None: + if not os.path.exists(config.pretrained_weights): + model_dir = snapshot_download(config.pretrained_weights) + pretrained_weights = os.path.join( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + elif os.path.isfile(config.pretrained_weights): + pretrained_weights = config.pretrained_weights + else: + pretrained_weights = os.path.join( + config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) + model.load_state_dict(torch.load(pretrained_weights)) + return model + + +class PromptModule(nn.Module): + """The implementation of vision prompt tuning method. + + Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens + and prepend to the original tokens in the first layer or multiple layers. + 'Visual Prompt Tuning' by Jia et al.(2022) + See https://arxiv.org/abs/2203.12119 + + Attributes: + dim: An integer indicating the embedding dimension. + layer_num: An integer indicating number of layers. + prompt_length: An integer indicating the length of vision prompt tuning. + """ + + def __init__(self, + dim, + layer_num, + prompt_length=None, + mask_values=0., + attach_front=True): + super(PromptModule, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.prompt_length = prompt_length + self.mask_values = mask_values + self.attach_front = attach_front + + self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim)) + nn.init.xavier_uniform_(self.prompt_token) + + def forward(self, x): + prompt_token = self.prompt_token.expand(x.shape[0], -1, -1) + + if self.layer_num == 0: + if self.attach_front: + x = torch.cat((prompt_token, x), dim=1) + else: + x = torch.cat((x, prompt_token), dim=1) + else: + if self.attach_front: + x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), + dim=1) + else: + x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), + dim=1) + return x + + def patch_attention_mask(self, m): + prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), + self.mask_values).to(m.device) + return torch.cat((prefix_attention_mask, m), dim=-1) diff --git a/modelscope/swift/sd_lora.py b/modelscope/swift/sd_lora.py new file mode 100644 index 00000000..feff05f4 --- /dev/null +++ b/modelscope/swift/sd_lora.py @@ -0,0 +1,218 @@ +# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# The implementation is adopted from HighCWu, +# made pubicly available under the Apache License 2.0 License at https://github.com/HighCWu/ControlLoRA +import os +from dataclasses import dataclass +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.outputs import BaseOutput + + +@dataclass +class TunerOutput(BaseOutput): + lora_states: Tuple[torch.FloatTensor] + + +class LoRACrossAttnProcessor(nn.Module): + """ The implementation of lora attention module. + """ + + def __init__(self, + hidden_size, + cross_attention_dim=None, + rank=4, + post_add=False, + key_states_skipped=False, + value_states_skipped=False, + output_states_skipped=False): + """ Initialize a lora attn instance. + Args: + hidden_size (`int`): The number of channels in embedding. + cross_attention_dim (`int`, *optional*): + The number of channels in the hidden_states. If not given, defaults to `hidden_size`. + rank (`int`, *optional*, defaults to 4): The number of rank of lora. + post_add (`bool`, *optional*, defaults to False): Set to `True`, conduct weighted + adding operation after lora. + key_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on key value. + value_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on value. + output_states_skipped (`bool`, *optional*, defaults to False): + Set to `True` for skip to perform lora on output value. + """ + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.post_add = post_add + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + if not key_states_skipped: + self.to_k_lora = LoRALinearLayer( + hidden_size if post_add else + (cross_attention_dim or hidden_size), hidden_size, rank) + if not value_states_skipped: + self.to_v_lora = LoRALinearLayer( + hidden_size if post_add else + (cross_attention_dim or hidden_size), hidden_size, rank) + if not output_states_skipped: + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + self.key_states_skipped: bool = key_states_skipped + self.value_states_skipped: bool = value_states_skipped + self.output_states_skipped: bool = output_states_skipped + + def skip_key_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_k_lora') + self.key_states_skipped = is_skipped + + def skip_value_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_q_lora') + self.value_states_skipped = is_skipped + + def skip_output_states(self, is_skipped: bool = True): + if not is_skipped: + assert hasattr(self, 'to_out_lora') + self.output_states_skipped = is_skipped + + def __call__(self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask=attention_mask, + target_length=sequence_length, + batch_size=batch_size) + + query = attn.to_q(hidden_states) + query = query + scale * self.to_q_lora( + query if self.post_add else hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + if not self.key_states_skipped: + key = key + scale * self.to_k_lora( + key if self.post_add else encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + if not self.value_states_skipped: + value = value + scale * self.to_v_lora( + value if self.post_add else encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + out = attn.to_out[0](hidden_states) + if not self.output_states_skipped: + out = out + scale * self.to_out_lora( + out if self.post_add else hidden_states) + hidden_states = out + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LoRATuner(ModelMixin, ConfigMixin): + + @staticmethod + def tune( + model: nn.Module, + tuner_config=None, + pretrained_tuner=None, + ): + tuner = LoRATuner.from_config(tuner_config) + if pretrained_tuner is not None and os.path.exists(pretrained_tuner): + tuner.load_state_dict( + torch.load(pretrained_tuner, map_location='cpu'), strict=True) + tune_layers_list = list( + [list(layer_list) for layer_list in tuner.lora_layers]) + assert hasattr(model, 'unet') + unet = model.unet + tuner.to(unet.device) + tune_attn_procs = tuner.set_tune_layers(unet, tune_layers_list) + unet.set_attn_processor(tune_attn_procs) + return tuner + + def set_tune_layers(self, unet, tune_layers_list): + n_ch = len(unet.config.block_out_channels) + control_ids = [i for i in range(n_ch)] + tune_attn_procs = {} + + for name in unet.attn_processors.keys(): + if name.startswith('mid_block'): + control_id = control_ids[-1] + elif name.startswith('up_blocks'): + block_id = int(name[len('up_blocks.')]) + control_id = list(reversed(control_ids))[block_id] + elif name.startswith('down_blocks'): + block_id = int(name[len('down_blocks.')]) + control_id = control_ids[block_id] + + tune_layers = tune_layers_list[control_id] + if len(tune_layers) != 0: + tune_layer = tune_layers.pop(0) + tune_attn_procs[name] = tune_layer + return tune_attn_procs + + @register_to_config + def __init__( + self, + lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + lora_cross_attention_dims: Tuple[List[int]] = ([ + None, 768, None, 768, None, 768, None, 768, None, 768 + ], [None, 768, None, 768, None, 768, None, 768, None, + 768], [None, 768, None, 768, None, 768, None, 768, None, + 768], [None, 768]), + lora_rank: int = 4, + lora_post_add: bool = False, + lora_key_states_skipped: bool = False, + lora_value_states_skipped: bool = False, + lora_output_states_skipped: bool = False, + ): + super().__init__() + + lora_cls = LoRACrossAttnProcessor + + self.lora_layers = nn.ModuleList([]) + + for i, lora_cross_attention_dim in enumerate( + lora_cross_attention_dims): + self.lora_layers.append( + nn.ModuleList([ + lora_cls( + lora_block_out_channels[i], + cross_attention_dim=cross_attention_dim, + rank=lora_rank, + post_add=lora_post_add, + key_states_skipped=lora_key_states_skipped, + value_states_skipped=lora_value_states_skipped, + output_states_skipped=lora_output_states_skipped) + for cross_attention_dim in lora_cross_attention_dim + ])) + + def forward(self) -> Union[TunerOutput, Tuple]: + lora_states_list = [] + tune_layers_list = list( + [list(layer_list) for layer_list in self.lora_layers]) + for tune_list in tune_layers_list: + for tune_layer in tune_list: + lora_states_list.append(tune_layer.to_q_lora.down.weight) + return TunerOutput(lora_states=tuple(lora_states_list)) diff --git a/modelscope/trainers/optimizer/__init__.py b/modelscope/trainers/optimizer/__init__.py index 9962c2c2..cd59c072 100644 --- a/modelscope/trainers/optimizer/__init__.py +++ b/modelscope/trainers/optimizer/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.swift import ChildTuningAdamW from .builder import OPTIMIZERS, build_optimizer -from .child_tuning_adamw_optimizer import ChildTuningAdamW __all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW'] diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index fd0fafb8..d8d87826 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -44,6 +44,7 @@ from modelscope.utils.registry import build_from_cfg from modelscope.utils.torch_utils import (compile_model, get_dist_info, get_local_rank, init_dist, is_dist, is_master, set_random_seed) +from ..swift import Swift from .base import BaseTrainer from .builder import TRAINERS from .default_config import merge_cfg, merge_hooks, update_cfg @@ -264,10 +265,7 @@ class EpochBasedTrainer(BaseTrainer): def tune_module(self, efficient_tuners): if efficient_tuners is not None: for tuner in efficient_tuners: - type = tuner.pop('type') - if type == 'lora': - from modelscope.tuners.lora import LoRATuner - LoRATuner.tune(self.model, **tuner) + self.model = Swift.prepare_model(self.model, tuner) def place_model(self): """Place model to device, or to DDP diff --git a/modelscope/utils/__init__.py b/modelscope/utils/__init__.py index e69de29b..7486e137 100644 --- a/modelscope/utils/__init__.py +++ b/modelscope/utils/__init__.py @@ -0,0 +1 @@ +from .hub import create_model_if_not_exist, read_config diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index 797351aa..ef0ccae7 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -8,18 +8,17 @@ from modelscope.metainfo import Preprocessors, Trainers from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline +from modelscope.swift.optimizers.child_tuning_adamw_optimizer import \ + calculate_fisher from modelscope.trainers import build_trainer from modelscope.trainers.hooks import Hook from modelscope.trainers.nlp_trainer import (EpochBasedTrainer, NlpEpochBasedTrainer) -from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ - calculate_fisher from modelscope.trainers.training_args import TrainingArgs from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device from modelscope.utils.regress_test_utils import (MsRegressTool, compare_arguments_nested) -from modelscope.utils.test_utils import test_level class TestFinetuneSequenceClassification(unittest.TestCase): diff --git a/tests/trainers/test_finetune_vision_efficient_tuning_swift.py b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py new file mode 100644 index 00000000..d8733024 --- /dev/null +++ b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py @@ -0,0 +1,164 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.swift.lora import LoRAConfig +from modelscope.swift.prompt import PromptConfig +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + + +class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + self.train_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='train') + + self.eval_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='eval') + + self.max_epochs = 1 + self.num_classes = 102 + self.tune_length = 10 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_lora_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.lora_length = 0 + return cfg + + lora_config = LoRAConfig( + rank=self.tune_length, + replace_modules=['qkv'], + merge_weights=False, + only_lora_trainable=False, + use_merged_linear=True, + enable_lora=[True]) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[lora_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-lora train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_adapter_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.adapter_length = 0 + return cfg + + adapter_config = AdapterConfig( + dim=768, + hidden_pos=0, + module_name=r'.*blocks\.\d+\.mlp$', + adapter_length=self.tune_length, + only_adapter_trainable=False) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[adapter_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-adapter train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_swift_prompt_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.prompt_length = 0 + return cfg + + prompt_config = PromptConfig( + dim=768, + module_layer_name=r'.*blocks\.\d+$', + embedding_pos=0, + prompt_length=self.tune_length, + only_prompt_trainable=False, + attach_front=False) + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn, + efficient_tuners=[prompt_config]) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-prompt train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tuners/test_adapter.py b/tests/tuners/test_adapter.py new file mode 100644 index 00000000..a110591a --- /dev/null +++ b/tests/tuners/test_adapter.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch + +from modelscope import read_config +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestAdapter(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip in this level') + def test_adapter_smoke_test(self): + dataset = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) + + model_dir = snapshot_download( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + model = Model.from_pretrained(model_dir, adv_grad_factor=None) + + cfg_file = os.path.join(model_dir, 'configuration.json') + + model_cfg = os.path.join(model_dir, 'config.json') + model_cfg = read_config(model_cfg) + + adapter_config = AdapterConfig( + dim=model_cfg.hidden_size, + module_name=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0) + model = Swift.prepare_model(model, adapter_config) + kwargs = dict( + model=model, + cfg_file=cfg_file, + train_dataset=dataset, + eval_dataset=dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + adapter_config.pretrained_weights = output_dir + Swift.prepare_model(model, adapter_config) + model.eval() + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + return pipeline_ins(input=('test', 'this is a test')) + + output1 = pipeline_sentence_similarity( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + print(output1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tuners/test_lora.py b/tests/tuners/test_lora.py index 2f52a4d3..b3238dad 100644 --- a/tests/tuners/test_lora.py +++ b/tests/tuners/test_lora.py @@ -11,9 +11,10 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.lora import (Linear, LoRA, LoRAConfig, + mark_only_lora_as_trainable) from modelscope.trainers import build_trainer -from modelscope.tuners.lora import (Linear, LoRATuner, - mark_only_lora_as_trainable) from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -66,22 +67,18 @@ class TestLora(unittest.TestCase): model_dir = snapshot_download( 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - model = Model.from_pretrained( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny', - adv_grad_factor=None) + model = Model.from_pretrained(model_dir, adv_grad_factor=None) cfg_file = os.path.join(model_dir, 'configuration.json') + lora_config = LoRAConfig(replace_modules=['query', 'key', 'value']) + model = Swift.prepare_model(model, lora_config) kwargs = dict( model=model, cfg_file=cfg_file, train_dataset=dataset, eval_dataset=dataset, - work_dir=self.tmp_dir, - efficient_tuners=[{ - 'type': 'lora', - 'replace_modules': ['query', 'key', 'value'] - }]) + work_dir=self.tmp_dir) trainer = build_trainer(default_args=kwargs) trainer.train() @@ -89,7 +86,8 @@ class TestLora(unittest.TestCase): def pipeline_sentence_similarity(model_dir): model = Model.from_pretrained(model_dir) - LoRATuner.tune(model, replace_modules=['query', 'key', 'value']) + lora_config.pretrained_weights = output_dir + Swift.prepare_model(model, lora_config) model.load_state_dict( torch.load(os.path.join(output_dir, 'pytorch_model.bin'))) model.eval() @@ -100,7 +98,7 @@ class TestLora(unittest.TestCase): output1 = pipeline_sentence_similarity( 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - LoRATuner.unpatch_lora(model, ['query', 'key', 'value']) + LoRA.unpatch_lora(model, lora_config) model.save_pretrained( output_dir, save_checkpoint_names='pytorch_model.bin') diff --git a/tests/tuners/test_prompt.py b/tests/tuners/test_prompt.py new file mode 100644 index 00000000..c338162f --- /dev/null +++ b/tests/tuners/test_prompt.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch + +from modelscope import read_config +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.swift import Swift +from modelscope.swift.adapter import AdapterConfig +from modelscope.swift.prompt import PromptConfig +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestPrompt(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip in this level') + def test_prompt_smoke_test(self): + dataset = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) + + model_dir = snapshot_download( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + model = Model.from_pretrained(model_dir, adv_grad_factor=None) + + cfg_file = os.path.join(model_dir, 'configuration.json') + model_cfg = os.path.join(model_dir, 'config.json') + model_cfg = read_config(model_cfg) + + prompt_config = PromptConfig( + dim=model_cfg.hidden_size, + module_layer_name=r'.*layer\.\d+$', + embedding_pos=0, + attention_mask_pos=1) + + model = Swift.prepare_model(model, prompt_config) + + kwargs = dict( + model=model, + cfg_file=cfg_file, + train_dataset=dataset, + eval_dataset=dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + prompt_config.pretrained_weights = output_dir + Swift.prepare_model(model, prompt_config) + model.eval() + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + return pipeline_ins(input=('test', 'this is a test')) + + output1 = pipeline_sentence_similarity( + 'damo/nlp_structbert_sentence-similarity_chinese-tiny') + print(output1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_ast.py b/tests/utils/test_ast.py index 288c076a..544e75b6 100644 --- a/tests/utils/test_ast.py +++ b/tests/utils/test_ast.py @@ -35,7 +35,7 @@ class AstScaningTest(unittest.TestCase): def test_ast_scaning_class(self): astScaner = AstScanning() pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', - 'text_generation_pipeline.py') + 'fill_mask_pipeline.py') output = astScaner.generate_ast(pipeline_file) self.assertTrue(output['imports'] is not None) self.assertTrue(output['from_imports'] is not None) @@ -45,24 +45,19 @@ class AstScaningTest(unittest.TestCase): self.assertIsInstance(imports, dict) self.assertIsInstance(from_imports, dict) self.assertIsInstance(decorators, list) - self.assertListEqual( - list(set(imports.keys()) - set(['torch', 'os'])), []) - self.assertEqual(len(from_imports.keys()), 11) + self.assertListEqual(list(set(imports.keys()) - set(['numpy'])), []) + self.assertEqual(len(from_imports.keys()), 8) self.assertTrue(from_imports['modelscope.metainfo'] is not None) self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) - self.assertEqual( - decorators, - [('PIPELINES', 'text-generation', 'text-generation'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_de'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_ro'), - ('PIPELINES', 'text2text-generation', 'translation_en_to_fr'), - ('PIPELINES', 'text2text-generation', 'text2text-generation')]) + self.assertEqual(decorators, + [('PIPELINES', 'fill-mask', 'fill-mask'), + ('PIPELINES', 'fill-mask', 'fill-mask-ponet')]) def test_files_scaning_method(self): fileScaner = FilesAstScanning() # case of pass in files directly pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', - 'text_generation_pipeline.py') + 'fill_mask_pipeline.py') file_list = [pipeline_file] output = fileScaner.get_files_scan_results(file_list) self.assertTrue(output[INDEX_KEY] is not None)