add quantization in qwen pipelines and relevant unittests
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13499600

* add quant features
* resolve import
* resolve format
* fix save vocab
commit 2b27144384
parent d160ff1e8b
committed by wenmeng.zwm
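What the commit enables, as a minimal usage sketch. The model id and the kwarg plumbing through pipeline() are assumptions taken from the unittests at the bottom of this diff, not guarantees:

# Sketch only: model id and kwarg forwarding assumed from the tests below.
import torch
from transformers import BitsAndBytesConfig

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    task=Tasks.text_generation,
    model='qwen/Qwen-7B',  # assumed model id; the tests use self.qwen_base
    device_map='auto',
    use_max_memory=True,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True))
print(pipe('Hello'))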
@@ -13,9 +13,7 @@ cd modelscope
 pip install -r requirements.txt
 pip install .
 """

 import os
 # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
 from dataclasses import dataclass, field
 from functools import partial
 from types import MethodType
@@ -150,19 +150,19 @@ class QWenTokenizer(PreTrainedTokenizer):
                 format(len(ids), self.max_len))
         return ids

-    def save_vocabulary(self,
-                        save_directory: str,
-                        filename_prefix: Optional[str] = None) -> Tuple[str]:
+    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
         """
         Save only the vocabulary of the tokenizer (vocabulary + added tokens).

         Returns:
             `Tuple(str)`: Paths to the files saved.
         """
-        file_path = os.path.join(save_directory, filename_prefix)
-        with open(file_path, 'w') as f:
-            json.dump(f, self.mergeable_ranks)
-        return file_path
+        file_path = os.path.join(save_directory, 'qwen.tiktoken')
+        with open(file_path, 'w', encoding='utf8') as w:
+            for k, v in self.mergeable_ranks.items():
+                line = base64.b64encode(k).decode('utf8') + ' ' + str(v) + '\n'
+                w.write(line)
+        return (file_path, )

     def tokenize(self, text: str, **kwargs) -> List[str]:
         """
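The new format writes one base64-encoded token and its rank per line, matching tiktoken's file layout. For reference, a minimal sketch of the inverse load; the helper name is illustrative and not part of the commit:

# Sketch: read a qwen.tiktoken file written by save_vocabulary back into
# a {bytes: int} rank table (the inverse of the base64 line format above).
import base64

def load_tiktoken_ranks(file_path: str) -> dict:
    mergeable_ranks = {}
    with open(file_path, encoding='utf8') as f:
        for line in f:
            token_b64, rank = line.split()  # base64 output never contains spaces
            mergeable_ranks[base64.b64decode(token_b64)] = int(rank)
    return mergeable_ranks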
@@ -199,7 +199,8 @@ class ChatGLM6bTextGenerationPipeline(Pipeline):
                  quantization_bit=None,
                  use_bf16=False,
                  **kwargs):
-        from modelscope.models.nlp.chatglm.text_generation import ChatGLMForConditionalGeneration, ChatGLMConfig
+        from modelscope.models.nlp.chatglm.text_generation import (
+            ChatGLMConfig, ChatGLMForConditionalGeneration)
         if isinstance(model, str):
             model_dir = snapshot_download(
                 model) if not os.path.exists(model) else model
@@ -241,7 +242,9 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
                  quantization_bit=None,
                  use_bf16=False,
                  **kwargs):
-        from modelscope.models.nlp import ChatGLM2ForConditionalGeneration, ChatGLM2Tokenizer, ChatGLM2Config
+        from modelscope.models.nlp import (ChatGLM2Config,
+                                           ChatGLM2ForConditionalGeneration,
+                                           ChatGLM2Tokenizer)
         if isinstance(model, str):
             model_dir = snapshot_download(
                 model) if not os.path.exists(model) else model
@@ -279,9 +282,19 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
 class QWenChatPipeline(Pipeline):

     def __init__(self, model: Union[Model, str], **kwargs):
-        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration
+        from modelscope.models.nlp import (QWenConfig, QWenForTextGeneration,
+                                           QWenTokenizer)
         torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
         device_map = kwargs.get('device_map', 'auto')
+        use_max_memory = kwargs.get('use_max_memory', False)
+        quantization_config = kwargs.get('quantization_config', None)
+
+        if use_max_memory:
+            max_memory = f'{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB'
+            n_gpus = torch.cuda.device_count()
+            max_memory = {i: max_memory for i in range(n_gpus)}
+        else:
+            max_memory = None
+
         if isinstance(model, str):
             model_dir = snapshot_download(
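The max_memory branch reserves all but roughly 2 GiB of the free memory reported for the current device and applies that cap to every visible GPU. The same computation as a standalone sketch (helper name and headroom parameter are illustrative):

# Sketch of the max_memory computation above: cap each visible GPU at the
# currently free memory on device 0 minus a fixed headroom.
import torch

def build_max_memory(headroom_gib: int = 2):
    if not torch.cuda.is_available():
        return None
    free_bytes, _total = torch.cuda.mem_get_info()  # (free, total) in bytes
    per_gpu = f'{int(free_bytes / 1024 ** 3) - headroom_gib}GB'
    return {i: per_gpu for i in range(torch.cuda.device_count())}

Note that the cap is read from a single device but applied uniformly; on GPUs with unequal free memory, a per-device torch.cuda.mem_get_info(i) reading would be tighter.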
@@ -296,7 +309,9 @@ class QWenChatPipeline(Pipeline):
             cfg_dict=config,
             config=model_config,
             device_map=device_map,
-            torch_dtype=torch_dtype)
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
+            max_memory=max_memory)
         model.generation_config = GenerationConfig.from_pretrained(
             model_dir)
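quantization_config and max_memory are standard transformers from_pretrained kwargs, so the pipeline is effectively delegating quantized loading and sharding to accelerate/bitsandbytes. An equivalent plain-transformers load, sketched with a hypothetical checkpoint path:

# Sketch: the same kwargs on a plain transformers load. The path is
# hypothetical; quantization_config and max_memory are standard
# from_pretrained kwargs handled by accelerate/bitsandbytes.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    '/path/to/qwen-7b',  # hypothetical local checkpoint
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    max_memory={0: '22GB'},
    trust_remote_code=True)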
@@ -330,9 +345,19 @@ class QWenChatPipeline(Pipeline):
 class QWenTextGenerationPipeline(Pipeline):

     def __init__(self, model: Union[Model, str], **kwargs):
-        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration
+        from modelscope.models.nlp import (QWenConfig, QWenForTextGeneration,
+                                           QWenTokenizer)
         torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
         device_map = kwargs.get('device_map', 'auto')
+        use_max_memory = kwargs.get('use_max_memory', False)
+        quantization_config = kwargs.get('quantization_config', None)
+
+        if use_max_memory:
+            max_memory = f'{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB'
+            n_gpus = torch.cuda.device_count()
+            max_memory = {i: max_memory for i in range(n_gpus)}
+        else:
+            max_memory = None
+
         if isinstance(model, str):
             model_dir = snapshot_download(
@@ -347,7 +372,9 @@ class QWenTextGenerationPipeline(Pipeline):
             cfg_dict=config,
             config=model_config,
             device_map=device_map,
-            torch_dtype=torch_dtype)
+            torch_dtype=torch_dtype,
+            quantization_config=quantization_config,
+            max_memory=max_memory)
         model.generation_config = GenerationConfig.from_pretrained(
             model_dir)
@@ -3,6 +3,7 @@
 import unittest

 import torch
+from transformers import BitsAndBytesConfig

 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
@@ -65,6 +66,37 @@ class QWenTextGenerationPipelineTest(unittest.TestCase):
                 'device_map': 'auto',
             })

+    # 7B_ms_base
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_base_with_text_generation_quant_int8(self):
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        self.run_pipeline_with_model_id(
+            self.qwen_base,
+            self.qwen_base_input,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
+    # 7B_ms_base
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_base_with_text_generation_quant_int4(self):
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type='nf4',
+            bnb_4bit_compute_dtype=torch.bfloat16)
+
+        self.run_pipeline_with_model_id(
+            self.qwen_base,
+            self.qwen_base_input,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
     # 7B_ms_chat
     @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_qwen_chat_with_chat(self):
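As a rough sanity check on why the int8 and nf4 settings matter for a 7B checkpoint, a weights-only footprint estimate per setting (ignores activations, KV cache, and quantization overhead):

# Back-of-envelope weight footprint for a 7B-parameter model.
params = 7e9
for name, bytes_per_param in [('bf16', 2), ('int8', 1), ('nf4', 0.5)]:
    print(f'{name}: ~{params * bytes_per_param / 1024 ** 3:.1f} GiB')
# bf16: ~13.0 GiB, int8: ~6.5 GiB, nf4: ~3.3 GiB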
@@ -76,6 +108,39 @@ class QWenTextGenerationPipelineTest(unittest.TestCase):
                 'device_map': 'auto',
             })

+    # 7B_ms_chat
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_chat_with_chat_quant_int8(self):
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        self.run_chat_pipeline_with_model_id(
+            self.qwen_chat,
+            self.qwen_chat_input,
+            self.qwen_chat_system,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+
+    # 7B_ms_chat
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    def test_qwen_chat_with_chat_quant_int4(self):
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type='nf4',
+            bnb_4bit_compute_dtype=torch.bfloat16)
+
+        self.run_chat_pipeline_with_model_id(
+            self.qwen_chat,
+            self.qwen_chat_input,
+            self.qwen_chat_system,
+            init_kwargs={
+                'device_map': 'auto',
+                'use_max_memory': True,
+                'quantization_config': quantization_config,
+            })
+

 if __name__ == '__main__':
     unittest.main()