Mirror of https://github.com/modelscope/modelscope.git
Fix/chatglm2 (#384)
@@ -49,11 +49,9 @@ from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.registry import default_group
 
 #
-SYSTEM_TEXT = """{system}"""
-USER_TEXT = """\n\n### 用户
-{user}"""
-ASSISTANT_PROMPT = """\n\n### 助手
-"""
+PROMPT = """System: {system}
+Human: {user}
+AI: """
 MAX_LENGTH = 2048
 TEST_MAX_LENGTH = MAX_LENGTH
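For reference, a minimal sketch (not part of the diff; the system/user strings are invented) of how the new single PROMPT template replaces the three old SYSTEM_TEXT/USER_TEXT/ASSISTANT_PROMPT pieces:

PROMPT = """System: {system}
Human: {user}
AI: """

src_text = PROMPT.format(
    system='You are a helpful assistant.',
    user='What is the capital of France?')
print(src_text)
# System: You are a helpful assistant.
# Human: What is the capital of France?
# AI: 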
@@ -62,11 +60,6 @@ logger = get_logger()
 #
 
 
-def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
-    model_dir = snapshot_download(model_id, model_revision)
-    return model_dir
-
-
 def _get_version(work_dir: str) -> int:
     if os.path.isdir(work_dir):
         fnames = os.listdir(work_dir)
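The deleted get_model_dir was a one-line wrapper over snapshot_download, so call sites can use the hub API directly; a sketch, assuming the standard modelscope import (the model id below is only an example taken from later in this diff):

from modelscope.hub.snapshot_download import snapshot_download

# Equivalent to the removed get_model_dir(model_id, model_revision):
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', None)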
@@ -93,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
     return work_dir
 
 
-def select_device(device_ids: List[int]) -> Device:
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+    if isinstance(device, list):
+        device_ids = device
+        device_str = ','.join([str(d) for d in device])
+    else:
+        device_ids = [int(d) for d in device.split(',') if d != '-1']
+        device_str = device
+    device_str = device_str.replace(' ', '')
+    return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
     """Call this function before cuda is initialized.
     Return: master device
     device: e.g. []: 'cpu', [0], [0, 1, 2]
+        e.g. '-1': 'cpu', '0', '0,1,2'
     """
     if torch.cuda.is_initialized():
         logger.warning('CUDA has been initialized! Device selection fails!')
         return torch.device('cuda:0')
     #
+    device_ids, device_str = _format_device(device)
+    #
+    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
     log_s = 'Using device: '
-    if len(device_ids) == 0:  # cpu
-        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-        device: str = 'cpu'
-        log_s += device
+    if len(device_ids) == 0:
+        master_device: str = 'cpu'
+        log_s += 'cpu'
     else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-            [str(d) for d in device_ids])
         assert torch.cuda.is_available(
         ) and torch.cuda.device_count() >= len(device_ids)
-        log_s += f"cuda:{','.join([str(d) for d in device_ids])}"  # e.g. 'cuda:1,7,8'
-        device = 'cuda:0'
+        master_device = 'cuda:0'
+        log_s += f'cuda:{device_str}'
     logger.info(log_s)
-    return torch.device(device)
+    return torch.device(master_device)
 
 
 def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
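A usage sketch for the reworked select_device (the device indices below are examples); both input forms normalize through the new _format_device helper:

# List form: [] selects cpu; [0, 1] sets CUDA_VISIBLE_DEVICES='0,1'
# and returns torch.device('cuda:0') as the master device.
master_device = select_device([0, 1])

# String form: '-1' selects cpu; '0,1' behaves like [0, 1].
master_device = select_device('0,1')

# As the docstring warns, call this before CUDA is initialized,
# otherwise setting CUDA_VISIBLE_DEVICES has no effect.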
@@ -148,37 +153,27 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
 def tokenize_function(system: str, user: str, assistant: Optional[str],
                       tokenizer) -> Dict[str, Any]:
     """Only applicable to baichuan and chatglm2. Other models need to be tested"""
-    system_text = SYSTEM_TEXT.format(system=system)
-    user_text = USER_TEXT.format(user=user)
-    system_text_ids: List[int] = tokenizer(
-        system_text, return_attention_mask=False,
+    src_text = PROMPT.format(system=system, user=user)
+    src_input_ids: List[int] = tokenizer(
+        src_text, return_attention_mask=False,
         add_special_tokens=True)['input_ids']
-    user_text_ids: List[int] = tokenizer(
-        user_text, return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-    assistant_p_input_ids: List[int] = tokenizer(
-        ASSISTANT_PROMPT,
-        return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-
-    # tokenizer.bos_token_id: Avoid `assistant` being empty
-    assistant_input_ids: List[int] = [tokenizer.bos_token_id]
+    #
+    tgt_input_ids: List[int] = []
     if assistant is not None:
-        assistant_input_ids += tokenizer(
+        tgt_input_ids += tokenizer(
             assistant, return_attention_mask=False,
             add_special_tokens=False)['input_ids']
-        assistant_input_ids += [tokenizer.eos_token_id]
-    #
-    input_ids = system_text_ids + user_text_ids + assistant_p_input_ids + assistant_input_ids
-    if assistant is not None:  # train, val
+        tgt_input_ids += [tokenizer.eos_token_id]
+        labels = [-100] * len(src_input_ids) + tgt_input_ids
+    else:
+        labels = None
+    input_ids = src_input_ids + tgt_input_ids
+    #
+    if assistant is not None:
         if len(input_ids) > MAX_LENGTH:
             return {}
-        len_mask = len(input_ids) - len(assistant_input_ids)
-        labels = [-100] * len_mask + assistant_input_ids
-    else:  # test
+    else:
         input_ids = input_ids[-TEST_MAX_LENGTH:]
-        labels = None
-
     #
     return {'input_ids': input_ids, 'labels': labels}
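To make the new label masking concrete, a small sketch (all token ids invented) of what tokenize_function now returns for a training sample:

# Suppose the prompt encodes to 5 ids and the reply to 3 ids plus eos:
src_input_ids = [101, 5, 6, 7, 8]   # tokenizer(PROMPT.format(...))
tgt_input_ids = [21, 22, 23, 2]     # tokenizer(assistant) + [eos_token_id]

input_ids = src_input_ids + tgt_input_ids
labels = [-100] * len(src_input_ids) + tgt_input_ids
# labels == [-100, -100, -100, -100, -100, 21, 22, 23, 2]
# -100 is the index CrossEntropyLoss ignores, so the loss is computed
# only on the assistant reply, never on the prompt tokens.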
@@ -305,12 +300,21 @@ class MyMetric(Metric):
         raise NotImplementedError
 
 
-def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/baichuan-7B'
-        model_dir = get_model_dir(model_id, None)
-    #
+def _add_special_token(tokenizer):
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = 2
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = 1
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = 0
+    logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+                f'eos_token_id: {tokenizer.eos_token_id}, '
+                f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_baichuan7B_model_tokenizer(model_dir: str,
+                                   load_model: bool = True,
+                                   add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaiChuanConfig
     from tokenization_baichuan import BaiChuanTokenizer
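The new _add_special_token helper backfills missing special-token ids with fixed defaults (pad=0, bos=1, eos=2); a usage sketch, reusing names from this hunk (the model id is the example from the removed code):

model_dir = snapshot_download('baichuan-inc/baichuan-7B', None)
tokenizer = BaiChuanTokenizer.from_pretrained(model_dir)
_add_special_token(tokenizer)
# Whichever of pad/bos/eos were None are now 0/1/2 respectively,
# and the final ids are logged for inspection.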
@@ -327,15 +331,14 @@ def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
         device_map='auto',
         torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
 
 
-def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'ZhipuAI/chatglm2-6b'
-        model_dir = snapshot_download(model_id, None)
-    #
+def get_chatglm2_model_tokenizer(model_dir: str,
+                                 load_model: bool = True,
+                                 add_special_token: bool = True):
     config = read_config(model_dir)
     config['model'] = ConfigDict({'type': 'chatglm2-6b'})
     tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -346,6 +349,8 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
         cfg_dict=config,
         device_map='auto',
         torch_dtype=torch.float16)
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
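End-to-end, callers now resolve the model directory themselves and pass it to the stricter signatures; a hedged usage sketch built only from names that appear in this diff:

model_dir = snapshot_download('ZhipuAI/chatglm2-6b', None)
model, tokenizer = get_chatglm2_model_tokenizer(
    model_dir, load_model=True, add_special_token=True)
# tokenizer is ready for tokenize_function; model is loaded in fp16
# with device_map='auto' as set inside the loader.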