mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
support openbuddy-llama2-13b (#416)
This commit is contained in:
69
examples/pytorch/llm/_parser.py
Normal file
69
examples/pytorch/llm/_parser.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch import device as Device
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from modelscope import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
|
||||
if isinstance(device, list):
|
||||
device_ids = device
|
||||
device_str = ','.join([str(d) for d in device])
|
||||
else:
|
||||
device_ids = [int(d) for d in device.split(',') if d != '-1']
|
||||
device_str = device
|
||||
device_str = device_str.replace(' ', '')
|
||||
return device_ids, device_str
|
||||
|
||||
|
||||
def select_device(device: Union[List[int], str]) -> Device:
|
||||
"""Call this function before cuda is initialized.
|
||||
device: e.g. []: 'cpu', [0], [0, 1, 2]
|
||||
e.g. '-1': 'cpu', '0', '0,1,2'
|
||||
"""
|
||||
if torch.cuda.is_initialized():
|
||||
logger.warning('CUDA has been initialized! Device selection fails!')
|
||||
return torch.device('cuda:0')
|
||||
|
||||
device_ids, device_str = _format_device(device)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
|
||||
log_s = 'Using device: '
|
||||
if len(device_ids) == 0:
|
||||
master_device: str = 'cpu'
|
||||
log_s += 'cpu'
|
||||
else:
|
||||
assert torch.cuda.is_available(
|
||||
) and torch.cuda.device_count() >= len(device_ids)
|
||||
master_device = 'cuda:0'
|
||||
log_s += f'cuda:{device_str}'
|
||||
logger.info(log_s)
|
||||
return torch.device(master_device)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def parse_args(class_type: Type[_T],
|
||||
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
|
||||
parser = HfArgumentParser([class_type])
|
||||
args, remaining_args = parser.parse_args_into_dataclasses(
|
||||
argv, return_remaining_strings=True)
|
||||
logger.info(f'args: {args}')
|
||||
return args, remaining_args
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceArguments:
|
||||
device: str = '0' # e.g. '-1'; '0'; '0,1'
|
||||
|
||||
|
||||
def parse_device(argv: Optional[List[str]] = None) -> List[str]:
|
||||
args, remaining_args = parse_args(DeviceArguments, argv)
|
||||
select_device(args.device)
|
||||
return remaining_args
|
||||
Reference in New Issue
Block a user