support openbuddy-llama2-13b (#416)

This commit is contained in:
Jintao
2023-07-26 18:12:55 +08:00
committed by GitHub
parent ba4db97507
commit 4ca937d2ba
9 changed files with 385 additions and 296 deletions

View File

@@ -0,0 +1,69 @@
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Type, TypeVar, Union
import torch
from torch import device as Device
from transformers import HfArgumentParser
from modelscope import get_logger
logger = get_logger()
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
if isinstance(device, list):
device_ids = device
device_str = ','.join([str(d) for d in device])
else:
device_ids = [int(d) for d in device.split(',') if d != '-1']
device_str = device
device_str = device_str.replace(' ', '')
return device_ids, device_str
def select_device(device: Union[List[int], str]) -> Device:
"""Call this function before cuda is initialized.
device: e.g. []: 'cpu', [0], [0, 1, 2]
e.g. '-1': 'cpu', '0', '0,1,2'
"""
if torch.cuda.is_initialized():
logger.warning('CUDA has been initialized! Device selection fails!')
return torch.device('cuda:0')
device_ids, device_str = _format_device(device)
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
log_s = 'Using device: '
if len(device_ids) == 0:
master_device: str = 'cpu'
log_s += 'cpu'
else:
assert torch.cuda.is_available(
) and torch.cuda.device_count() >= len(device_ids)
master_device = 'cuda:0'
log_s += f'cuda:{device_str}'
logger.info(log_s)
return torch.device(master_device)
_T = TypeVar('_T')
def parse_args(class_type: Type[_T],
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
parser = HfArgumentParser([class_type])
args, remaining_args = parser.parse_args_into_dataclasses(
argv, return_remaining_strings=True)
logger.info(f'args: {args}')
return args, remaining_args
@dataclass
class DeviceArguments:
device: str = '0' # e.g. '-1'; '0'; '0,1'
def parse_device(argv: Optional[List[str]] = None) -> List[str]:
args, remaining_args = parse_args(DeviceArguments, argv)
select_device(args.device)
return remaining_args