From bd2f70a6eb68eba988411d48a5e5609715f6b983 Mon Sep 17 00:00:00 2001
From: "lukeming.lkm"
Date: Wed, 2 Aug 2023 14:05:13 +0800
Subject: [PATCH] add quantization in qwen pipelines and relevant unittests

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13499600

* add quant features

* resolve import

* resolve format

* fix save vocab
---
 examples/pytorch/llm/llm_sft.py                    |  2 -
 modelscope/models/nlp/qwen/tokenization.py         | 14 ++--
 .../pipelines/nlp/text_generation_pipeline.py      | 39 +++++++++--
 .../test_qwen_text_generation_pipeline.py          | 65 +++++++++++++++++++
 4 files changed, 105 insertions(+), 15 deletions(-)

diff --git a/examples/pytorch/llm/llm_sft.py b/examples/pytorch/llm/llm_sft.py
index 7bcda665..aea3955c 100644
--- a/examples/pytorch/llm/llm_sft.py
+++ b/examples/pytorch/llm/llm_sft.py
@@ -13,9 +13,7 @@ cd modelscope
 pip install -r requirements.txt
 pip install .
 """
-
 import os
-# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
 from dataclasses import dataclass, field
 from functools import partial
 from types import MethodType
diff --git a/modelscope/models/nlp/qwen/tokenization.py b/modelscope/models/nlp/qwen/tokenization.py
index 2acc9b87..459ceaf3 100644
--- a/modelscope/models/nlp/qwen/tokenization.py
+++ b/modelscope/models/nlp/qwen/tokenization.py
@@ -150,19 +150,19 @@ class QWenTokenizer(PreTrainedTokenizer):
                 format(len(ids), self.max_len))
         return ids
 
-    def save_vocabulary(self,
-                        save_directory: str,
-                        filename_prefix: Optional[str] = None) -> Tuple[str]:
+    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
         """
         Save only the vocabulary of the tokenizer (vocabulary + added tokens).
 
         Returns:
             `Tuple(str)`: Paths to the files saved.
         """
-        file_path = os.path.join(save_directory, filename_prefix)
-        with open(file_path, 'w') as f:
-            json.dump(f, self.mergeable_ranks)
-        return file_path
+        file_path = os.path.join(save_directory, 'qwen.tiktoken')
+        with open(file_path, 'w', encoding='utf8') as w:
+            for k, v in self.mergeable_ranks.items():
+                line = base64.b64encode(k).decode('utf8') + ' ' + str(v) + '\n'
+                w.write(line)
+        return (file_path, )
 
     def tokenize(self, text: str, **kwargs) -> List[str]:
         """
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index ea88baa2..2a872f7f 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -199,7 +199,8 @@ class ChatGLM6bTextGenerationPipeline(Pipeline):
                  quantization_bit=None,
                  use_bf16=False,
                  **kwargs):
-        from modelscope.models.nlp.chatglm.text_generation import ChatGLMForConditionalGeneration, ChatGLMConfig
+        from modelscope.models.nlp.chatglm.text_generation import (
+            ChatGLMConfig, ChatGLMForConditionalGeneration)
         if isinstance(model, str):
             model_dir = snapshot_download(
                 model) if not os.path.exists(model) else model
@@ -241,7 +242,9 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
                  quantization_bit=None,
                  use_bf16=False,
                  **kwargs):
-        from modelscope.models.nlp import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer, ChatGLM2Config
+        from modelscope.models.nlp import (ChatGLM2Config,
+                                           ChatGLM2ForConditionalGeneration,
+                                           ChatGLM2Tokenizer)
         if isinstance(model, str):
             model_dir = snapshot_download(
                 model) if not os.path.exists(model) else model
@@ -279,9 +282,19 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
 class QWenChatPipeline(Pipeline):
 
     def __init__(self, model: Union[Model, str], **kwargs):
-        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration
+        from modelscope.models.nlp import (QWenConfig, QWenForTextGeneration,
+                                           QWenTokenizer)
         torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
         device_map = kwargs.get('device_map', 'auto')
+        use_max_memory = kwargs.get('use_max_memory', False)
+        quantization_config = kwargs.get('quantization_config', None)
+
+        if use_max_memory:
+            max_memory = f'{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB'
+            n_gpus = torch.cuda.device_count()
+            max_memory = {i: max_memory for i in range(n_gpus)}
+        else:
+            max_memory = None
 
         if isinstance(model, str):
             model_dir = snapshot_download(
@@ -296,7 +309,9 @@ class QWenChatPipeline(Pipeline):
             cfg_dict=config,
             config=model_config,
             device_map=device_map,
-            torch_dtype=torch_dtype)
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
+            max_memory=max_memory)
 
         model.generation_config = GenerationConfig.from_pretrained(
             model_dir)
@@ -330,9 +345,19 @@ class QWenChatPipeline(Pipeline):
 class QWenTextGenerationPipeline(Pipeline):
 
     def __init__(self, model: Union[Model, str], **kwargs):
-        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration
+        from modelscope.models.nlp import (QWenConfig, QWenForTextGeneration,
+                                           QWenTokenizer)
         torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
         device_map = kwargs.get('device_map', 'auto')
+        use_max_memory = kwargs.get('use_max_memory', False)
+        quantization_config = kwargs.get('quantization_config', None)
+
+        if use_max_memory:
+            max_memory = f'{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB'
+            n_gpus = torch.cuda.device_count()
+            max_memory = {i: max_memory for i in range(n_gpus)}
+        else:
+            max_memory = None
 
         if isinstance(model, str):
             model_dir = snapshot_download(
@@ -347,7 +372,9 @@ class QWenTextGenerationPipeline(Pipeline):
             cfg_dict=config,
             config=model_config,
             device_map=device_map,
-            torch_dtype=torch_dtype)
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
+            max_memory=max_memory)
 
         model.generation_config = GenerationConfig.from_pretrained(
             model_dir)
diff --git a/tests/pipelines/test_qwen_text_generation_pipeline.py b/tests/pipelines/test_qwen_text_generation_pipeline.py
index d8e3848b..48dba5aa 100644
--- a/tests/pipelines/test_qwen_text_generation_pipeline.py
+++ b/tests/pipelines/test_qwen_text_generation_pipeline.py
@@ -3,6 +3,7 @@
 import unittest
 
 import torch
+from transformers import BitsAndBytesConfig
 
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
@@ -65,6 +66,37 @@ class QWenTextGenerationPipelineTest(unittest.TestCase):
                 'device_map': 'auto',
             })
 
+    # 7B_ms_base
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_base_with_text_generation_quant_int8(self):
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        self.run_pipeline_with_model_id(
+            self.qwen_base,
+            self.qwen_base_input,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
+    # 7B_ms_base
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_base_with_text_generation_quant_int4(self):
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type='nf4',
+            bnb_4bit_compute_dtype=torch.bfloat16)
+
+        self.run_pipeline_with_model_id(
+            self.qwen_base,
+            self.qwen_base_input,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
     # 7B_ms_chat
     @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_qwen_chat_with_chat(self):
@@ -76,6 +108,39 @@ class QWenTextGenerationPipelineTest(unittest.TestCase):
                 'device_map': 'auto',
             })
 
+    # 7B_ms_chat
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_chat_with_chat_quant_int8(self):
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        self.run_chat_pipeline_with_model_id(
+            self.qwen_chat,
+            self.qwen_chat_input,
+            self.qwen_chat_system,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
+    # 7B_ms_chat
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_chat_with_chat_quant_int4(self):
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type='nf4',
+            bnb_4bit_compute_dtype=torch.bfloat16)
+
+        self.run_chat_pipeline_with_model_id(
+            self.qwen_chat,
+            self.qwen_chat_input,
+            self.qwen_chat_system,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
 
 if __name__ == '__main__':
     unittest.main()
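
Usage sketch (not part of the diff above): the snippet below shows how the new quantization_config and use_max_memory kwargs introduced by this patch could be exercised outside the unit tests. It assumes that pipeline() forwards these keyword arguments to the QWen pipeline constructor, as the test helpers appear to do via init_kwargs; the model id and prompt are placeholders, not values taken from the tests.

import torch
from transformers import BitsAndBytesConfig

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# 4-bit NF4 quantization, mirroring the *_quant_int4 unit tests in this patch.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16)

pipe = pipeline(
    task=Tasks.text_generation,
    model='qwen/Qwen-7B',  # placeholder model id; the tests use self.qwen_base
    device_map='auto',
    use_max_memory=True,  # triggers the per-GPU max_memory cap added in __init__
    quantization_config=quantization_config)

print(pipe('Write a short poem about the sea.'))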
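
For reference, the qwen.tiktoken file written by the reworked save_vocabulary() stores one token per line as base64-encoded token bytes followed by the integer rank. A minimal sketch of reading it back, assuming exactly that line format (the loader name is illustrative, not an existing API):

import base64

def load_mergeable_ranks(path: str) -> dict:
    # Each line: '<base64-encoded token bytes> <rank>'.
    ranks = {}
    with open(path, encoding='utf8') as f:
        for line in f:
            token_b64, rank = line.split()
            ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks

# Round-trip check against the tokenizer that wrote the file:
# assert load_mergeable_ranks('qwen.tiktoken') == tokenizer.mergeable_ranks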