mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-22 19:19:21 +01:00
Fix/chatglm2 (#384)
This commit is contained in:
@@ -51,25 +51,15 @@ from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
#
|
||||
TEST_SPLIT_P = 0.01
|
||||
SPLIT_SEED = 42
|
||||
MAX_LENGTH: Optional[int] = 2048
|
||||
COLOR, COLOR_S = '#FFE2D9', '#FF7043'
|
||||
|
||||
PROMPT = """### 用户
|
||||
{instruction}
|
||||
### AI助手
|
||||
"""
|
||||
PROMPT = """Human: {instruction}
|
||||
AI: """
|
||||
|
||||
logger = get_logger()
|
||||
#
|
||||
|
||||
|
||||
def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
|
||||
model_dir = snapshot_download(model_id, model_revision)
|
||||
return model_dir
|
||||
|
||||
|
||||
def _get_version(work_dir: str) -> int:
|
||||
if os.path.isdir(work_dir):
|
||||
fnames = os.listdir(work_dir)
|
||||
@@ -96,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
|
||||
return work_dir
|
||||
|
||||
|
||||
def select_device(device_ids: List[int]) -> Device:
|
||||
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
|
||||
if isinstance(device, list):
|
||||
device_ids = device
|
||||
device_str = ','.join([str(d) for d in device])
|
||||
else:
|
||||
device_ids = [int(d) for d in device.split(',') if d != '-1']
|
||||
device_str = device
|
||||
device_str = device_str.replace(' ', '')
|
||||
return device_ids, device_str
|
||||
|
||||
|
||||
def select_device(device: Union[List[int], str]) -> Device:
|
||||
"""Call this function before cuda is initialized.
|
||||
Return: master device
|
||||
device: e.g. []: 'cpu', [0], [0, 1, 2]
|
||||
e.g. '-1': 'cpu', '0', '0,1,2'
|
||||
"""
|
||||
if torch.cuda.is_initialized():
|
||||
logger.warning('CUDA has been initialized! Device selection fails!')
|
||||
return torch.device('cuda:0')
|
||||
#
|
||||
device_ids, device_str = _format_device(device)
|
||||
#
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
|
||||
log_s = 'Using device: '
|
||||
if len(device_ids) == 0: # cpu
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
device: str = 'cpu'
|
||||
log_s += device
|
||||
if len(device_ids) == 0:
|
||||
master_device: str = 'cpu'
|
||||
log_s += 'cpu'
|
||||
else:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
|
||||
[str(d) for d in device_ids])
|
||||
assert torch.cuda.is_available(
|
||||
) and torch.cuda.device_count() >= len(device_ids)
|
||||
log_s += f"cuda:{','.join([str(d) for d in device_ids])}" # e.g. 'cuda:1,7,8'
|
||||
device = 'cuda:0'
|
||||
master_device = 'cuda:0'
|
||||
log_s += f'cuda:{device_str}'
|
||||
logger.info(log_s)
|
||||
return torch.device(device)
|
||||
return torch.device(master_device)
|
||||
|
||||
|
||||
def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
|
||||
@@ -148,7 +150,9 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
|
||||
return T_max
|
||||
|
||||
|
||||
def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
|
||||
def tokenize_function(example: Dict[str, str],
|
||||
tokenizer,
|
||||
max_length: Optional[int] = 2048) -> Dict[str, Any]:
|
||||
"""Only applicable to baichuan and chatglm2. Other models need to be tested"""
|
||||
instruction = example['instruction']
|
||||
input_: str = example['input']
|
||||
@@ -159,12 +163,12 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
|
||||
else:
|
||||
instruction = instruction + input_
|
||||
output = example['output']
|
||||
src_text = PROMPT.format(instruction=instruction, add_special_tokens=False)
|
||||
src_text = PROMPT.format(instruction=instruction)
|
||||
src_input_ids: List[int] = tokenizer(
|
||||
src_text, return_attention_mask=False,
|
||||
add_special_tokens=True)['input_ids']
|
||||
# tokenizer.bos_token_id: Avoid `tgt_input_ids` being empty
|
||||
tgt_input_ids = [tokenizer.bos_token_id]
|
||||
#
|
||||
tgt_input_ids = []
|
||||
if output is not None:
|
||||
tgt_input_ids += tokenizer(
|
||||
output, return_attention_mask=False,
|
||||
@@ -175,10 +179,10 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
|
||||
labels = None
|
||||
input_ids = src_input_ids + tgt_input_ids
|
||||
#
|
||||
if MAX_LENGTH is not None:
|
||||
input_ids = input_ids[-MAX_LENGTH:]
|
||||
if max_length is not None:
|
||||
input_ids = input_ids[-max_length:]
|
||||
if labels is not None:
|
||||
labels = labels[-MAX_LENGTH:]
|
||||
labels = labels[-max_length:]
|
||||
#
|
||||
return {'input_ids': input_ids, 'labels': labels}
|
||||
|
||||
@@ -200,8 +204,10 @@ def stat_dataset(dataset: HFDataset) -> None:
|
||||
|
||||
def print_examples(examples: Dict[str, Any], tokenizer) -> None:
|
||||
input_ids, labels = examples['input_ids'], examples['labels']
|
||||
print(f'[INPUT_IDS] {tokenizer.decode(input_ids)}')
|
||||
print(f'[INPUT_IDS] {input_ids}')
|
||||
print(f'[INPUT] {tokenizer.decode(input_ids)}')
|
||||
print()
|
||||
print(f'[LABLES_IDS] {labels}')
|
||||
print(
|
||||
f'[LABLES] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
|
||||
)
|
||||
@@ -283,16 +289,25 @@ class MyMetric(Metric):
|
||||
}
|
||||
|
||||
def merge(self, other: 'MyMetric') -> None:
|
||||
"""This script does not support ddp"""
|
||||
"""This script does not support ddp. TODO"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
|
||||
load_model: bool = True):
|
||||
if model_dir is None:
|
||||
model_id = 'baichuan-inc/baichuan-7B'
|
||||
model_dir = get_model_dir(model_id, None)
|
||||
#
|
||||
def _add_special_token(tokenizer):
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token_id = 2
|
||||
if tokenizer.bos_token_id is None:
|
||||
tokenizer.bos_token_id = 1
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token_id = 0
|
||||
logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
|
||||
f'eos_token_id: {tokenizer.eos_token_id}, '
|
||||
f'pad_token_id: {tokenizer.pad_token_id}')
|
||||
|
||||
|
||||
def get_baichuan7B_model_tokenizer(model_dir: str,
|
||||
load_model: bool = True,
|
||||
add_special_token: bool = True):
|
||||
sys.path.insert(0, model_dir)
|
||||
from configuration_baichuan import BaiChuanConfig
|
||||
from tokenization_baichuan import BaiChuanTokenizer
|
||||
@@ -309,15 +324,14 @@ def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
#
|
||||
if add_special_token:
|
||||
_add_special_token(tokenizer)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
|
||||
load_model: bool = True):
|
||||
if model_dir is None:
|
||||
model_id = 'baichuan-inc/Baichuan-13B-Base'
|
||||
model_dir = get_model_dir(model_id, 'v1.0.1')
|
||||
#
|
||||
def get_baichuan13B_model_tokenizer(model_dir: str,
|
||||
load_model: bool = True,
|
||||
add_special_token: bool = True):
|
||||
sys.path.insert(0, model_dir)
|
||||
from configuration_baichuan import BaichuanConfig
|
||||
from tokenization_baichuan import BaichuanTokenizer
|
||||
@@ -334,15 +348,14 @@ def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
#
|
||||
if add_special_token:
|
||||
_add_special_token(tokenizer)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
|
||||
load_model: bool = True):
|
||||
if model_dir is None:
|
||||
model_id = 'ZhipuAI/chatglm2-6b'
|
||||
model_dir = snapshot_download(model_id, None)
|
||||
#
|
||||
def get_chatglm2_model_tokenizer(model_dir: str,
|
||||
load_model: bool = True,
|
||||
add_special_token: bool = True):
|
||||
config = read_config(model_dir)
|
||||
config['model'] = ConfigDict({'type': 'chatglm2-6b'})
|
||||
tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
|
||||
@@ -353,12 +366,16 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
|
||||
cfg_dict=config,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
if add_special_token:
|
||||
_add_special_token(tokenizer)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_alpaca_en_zh_dataset(
|
||||
tokenize_function,
|
||||
only_val: bool = False) -> Tuple[HFDataset, HFDataset]:
|
||||
only_val: bool = False,
|
||||
test_split_p: float = 0.01,
|
||||
split_seed: int = 42) -> Tuple[HFDataset, HFDataset]:
|
||||
"""
|
||||
split: Literal['train', 'validation', None]
|
||||
"""
|
||||
@@ -371,7 +388,7 @@ def get_alpaca_en_zh_dataset(
|
||||
dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en])
|
||||
#
|
||||
# dataset = dataset.select(range(1000)) # for debug
|
||||
dataset = dataset.train_test_split(TEST_SPLIT_P, seed=SPLIT_SEED)
|
||||
dataset = dataset.train_test_split(test_split_p, seed=split_seed)
|
||||
if only_val:
|
||||
dataset = dataset['test']
|
||||
if tokenize_function is not None:
|
||||
|
||||
Reference in New Issue
Block a user