Support `device_map` for transformers models

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13048856

* support device_map

* support device_map for T5

* fix bug
This commit is contained in:
hemu.zp
2023-06-25 20:15:11 +08:00
committed by wenmeng.zwm
parent 29062d9f94
commit 1421629392
3 changed files with 19 additions and 11 deletions

View File

@@ -57,7 +57,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight',
]
def __init__(self, config: T5Config, **kwargs):
def __init__(self, config: T5Config, device_map=None, **kwargs):
super().__init__(config)
self.model_dim = config.d_model
@@ -82,7 +82,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# Model parallel
self.model_parallel = False
self.device_map = None
if device_map == 'auto':
self.parallelize()
def parallelize(self, device_map=None):
self.device_map = (

View File

@@ -27,7 +27,8 @@ class UserSatisfactionEstimation(TorchModel):
def __init__(self,
model_dir: str,
bert_name: str = None,
device: str = None):
device: str = None,
**kwargs):
"""initialize the user satisfaction estimation model from the `model_dir` path. The default preprocessor
for this task is DialogueClassificationUsePreprocessor.
@@ -36,7 +37,7 @@ class UserSatisfactionEstimation(TorchModel):
bert_name: The pretrained model, default bert-base-chinese
device: The device of running model, default cpu
"""
super().__init__(model_dir)
super().__init__(model_dir, **kwargs)
self.model_dir = model_dir
self.bert_name = bert_name if bert_name is not None else 'bert-base-chinese'
self.device = 'cpu'

View File

@@ -54,7 +54,8 @@ class Pipeline(ABC):
model,
device=self.device_name,
model_prefetched=True,
invoked_by=Invoke.PIPELINE) if is_model(model) else model
invoked_by=Invoke.PIPELINE,
device_map=self.device_map) if is_model(model) else model
else:
return model
@@ -70,6 +71,7 @@ class Pipeline(ABC):
preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
device: str = 'gpu',
auto_collate=True,
device_map=None,
**kwargs):
""" Base class for pipeline.
@@ -87,6 +89,9 @@ class Pipeline(ABC):
compile_options (dict, optional): The compile options if compile=True,
default None to use the default params of 'TorchModel.compile'.
"""
if device_map is not None:
assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!'
self.device_map = device_map
verify_device(device)
self.device_name = device
@@ -133,13 +138,14 @@ class Pipeline(ABC):
self._model_prepare_lock.acquire(timeout=600)
def _prepare_single(model):
if isinstance(model, torch.nn.Module):
if not isinstance(model, torch.nn.Module) and hasattr(
model, 'model'):
model = model.model
if not isinstance(model, torch.nn.Module):
return
model.eval()
if self.device_map is None:
model.to(self.device)
model.eval()
elif hasattr(model, 'model') and isinstance(
model.model, torch.nn.Module):
model.model.to(self.device)
model.model.eval()
if not self._model_prepare:
# prepare model for pytorch