mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-23 19:49:24 +01:00
70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Optional, Tuple, Type, TypeVar, Union
|
|
|
|
import torch
|
|
from torch import device as Device
|
|
from transformers import HfArgumentParser
|
|
|
|
from modelscope import get_logger
|
|
|
|
# Module-level logger obtained from modelscope's get_logger(); the functions
# in this file use it for their info/warning messages.
logger = get_logger()
|
|
|
|
|
|
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
|
|
if isinstance(device, list):
|
|
device_ids = device
|
|
device_str = ','.join([str(d) for d in device])
|
|
else:
|
|
device_ids = [int(d) for d in device.split(',') if d != '-1']
|
|
device_str = device
|
|
device_str = device_str.replace(' ', '')
|
|
return device_ids, device_str
|
|
|
|
|
|
def select_device(device: Union[List[int], str]) -> Device:
    """Restrict the visible CUDA devices and return the master device.

    Call this function before cuda is initialized.

    Args:
        device: e.g. []: 'cpu', [0], [0, 1, 2]
            e.g. '-1': 'cpu', '0', '0,1,2'

    Returns:
        ``torch.device('cpu')`` when no GPU id was requested, otherwise
        ``torch.device('cuda:0')``. If CUDA is already initialized, a warning
        is logged, nothing is changed, and ``torch.device('cuda:0')`` is
        returned.

    Raises:
        AssertionError: If GPUs were requested but CUDA is unavailable or
            fewer devices are visible than were requested.
    """
    if torch.cuda.is_initialized():
        logger.warning('CUDA has been initialized! Device selection fails!')
        return torch.device('cuda:0')

    device_ids, device_str = _format_device(device)
    # Must be set before CUDA is initialized to have any effect, hence the
    # guard above.
    os.environ['CUDA_VISIBLE_DEVICES'] = device_str

    if not device_ids:
        logger.info('Using device: cpu')
        return torch.device('cpu')

    # Raised explicitly instead of via `assert`, so the check still runs
    # under `python -O`. AssertionError is kept as the exception type so
    # existing callers that catch it keep working.
    if not torch.cuda.is_available():
        raise AssertionError(
            f'CUDA is not available, but device {device_str!r} was requested.'
        )
    visible_count = torch.cuda.device_count()
    if visible_count < len(device_ids):
        raise AssertionError(
            f'Requested {len(device_ids)} CUDA device(s) ({device_str}), '
            f'but only {visible_count} are visible.')

    logger.info(f'Using device: cuda:{device_str}')
    # The master device is always index 0 of the visible devices.
    return torch.device('cuda:0')
|
|
|
|
|
|
# Generic type variable: parse_args uses it so that the returned instance is
# typed as the dataclass type passed in as `class_type`.
_T = TypeVar('_T')
|
|
|
|
|
|
def parse_args(class_type: Type[_T],
               argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
    """Parse command-line arguments into an instance of ``class_type``.

    Args:
        class_type: The dataclass type handed to ``HfArgumentParser``.
        argv: Argument strings to parse; passed through to
            ``parse_args_into_dataclasses`` unchanged (``None`` by default).

    Returns:
        A ``(args, remaining_args)`` tuple: the populated ``class_type``
        instance and the argument strings the parser did not consume.
    """
    hf_parser = HfArgumentParser([class_type])
    parsed = hf_parser.parse_args_into_dataclasses(
        argv, return_remaining_strings=True)
    # One dataclass was registered, so the result is exactly
    # (instance, remaining strings).
    args, remaining_args = parsed
    logger.info(f'args: {args}')
    return args, remaining_args
|
|
|
|
|
|
@dataclass
class DeviceArguments:
    """Command-line arguments that choose the device(s) to run on."""

    # Comma-separated device specification, e.g. '-1'; '0'; '0,1'.
    device: str = field(default='0')
|
|
|
|
|
|
def parse_device(argv: Optional[List[str]] = None) -> List[str]:
    """Parse the device argument, apply the device selection, return the rest.

    Args:
        argv: Argument strings to parse; forwarded to ``parse_args``
            (``None`` by default).

    Returns:
        The argument strings not consumed by ``DeviceArguments``.
    """
    device_args, leftover = parse_args(DeviceArguments, argv)
    # Called for its side effect of selecting the device; the returned
    # torch.device is not needed here.
    select_device(device_args.device)
    return leftover
|