Files
modelscope/examples/pytorch/llm/_parser.py
2023-07-26 18:12:55 +08:00

70 lines
2.0 KiB
Python

import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Type, TypeVar, Union
import torch
from torch import device as Device
from transformers import HfArgumentParser
from modelscope import get_logger
logger = get_logger()
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
if isinstance(device, list):
device_ids = device
device_str = ','.join([str(d) for d in device])
else:
device_ids = [int(d) for d in device.split(',') if d != '-1']
device_str = device
device_str = device_str.replace(' ', '')
return device_ids, device_str
def select_device(device: Union[List[int], str]) -> Device:
    """Select the compute device by setting ``CUDA_VISIBLE_DEVICES``.

    Call this function before cuda is initialized.

    Args:
        device: e.g. ``[]``: 'cpu', ``[0]``, ``[0, 1, 2]``
            e.g. ``'-1'``: 'cpu', ``'0'``, ``'0,1,2'``

    Returns:
        The master ``torch.device`` ('cpu', or 'cuda:0' when GPUs are used).

    Raises:
        RuntimeError: If GPUs were requested but CUDA is unavailable or
            fewer devices are visible than requested.
    """
    if torch.cuda.is_initialized():
        # Too late: CUDA_VISIBLE_DEVICES has no effect once CUDA is up.
        logger.warning('CUDA has been initialized! Device selection fails!')
        return torch.device('cuda:0')
    device_ids, device_str = _format_device(device)
    # Must be set before the first CUDA call for it to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
    log_s = 'Using device: '
    if len(device_ids) == 0:
        master_device: str = 'cpu'
        log_s += 'cpu'
    else:
        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would defer failure to a confusing CUDA error.
        if not torch.cuda.is_available():
            raise RuntimeError('CUDA is not available.')
        if torch.cuda.device_count() < len(device_ids):
            raise RuntimeError(
                f'Requested {len(device_ids)} GPU(s), but only '
                f'{torch.cuda.device_count()} are visible.')
        # cuda:0 is the first device of the now-restricted visible set.
        master_device = 'cuda:0'
        log_s += f'cuda:{device_str}'
    logger.info(log_s)
    return torch.device(master_device)
_T = TypeVar('_T')
def parse_args(class_type: Type[_T],
               argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
    """Parse command-line arguments into a ``class_type`` dataclass.

    Args:
        class_type: The dataclass describing the expected arguments.
        argv: Argument strings to parse; ``None`` means ``sys.argv``.

    Returns:
        A tuple of the populated dataclass instance and the list of
        argument strings that were not consumed.
    """
    parser = HfArgumentParser([class_type])
    parsed = parser.parse_args_into_dataclasses(
        argv, return_remaining_strings=True)
    args, remaining = parsed
    logger.info(f'args: {args}')
    return args, remaining
@dataclass
class DeviceArguments:
    """Command-line arguments controlling device selection."""

    # Comma-separated GPU ids; '-1' selects cpu. Examples: '-1', '0', '0,1'.
    device: str = field(default='0')
def parse_device(argv: Optional[List[str]] = None) -> List[str]:
    """Consume the ``--device`` option from argv and apply it.

    Parses ``DeviceArguments`` out of ``argv`` (or ``sys.argv`` when
    ``None``), selects the device as a side effect, and returns the
    argument strings left over for further parsing.
    """
    device_args, leftover = parse_args(DeviceArguments, argv)
    select_device(device_args.device)
    return leftover