mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Support device_map for transformers model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13048856 * support device_map * support device_map for T5 * fix bug
This commit is contained in:
@@ -57,7 +57,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight',
|
||||
]
|
||||
|
||||
def __init__(self, config: T5Config, **kwargs):
|
||||
def __init__(self, config: T5Config, device_map=None, **kwargs):
|
||||
super().__init__(config)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
@@ -82,7 +82,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
if device_map == 'auto':
|
||||
self.parallelize()
|
||||
|
||||
def parallelize(self, device_map=None):
|
||||
self.device_map = (
|
||||
|
||||
@@ -27,7 +27,8 @@ class UserSatisfactionEstimation(TorchModel):
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
bert_name: str = None,
|
||||
device: str = None):
|
||||
device: str = None,
|
||||
**kwargs):
|
||||
"""initialize the user satisfaction estimation model from the `model_dir` path. The default preprocessor
|
||||
for this task is DialogueClassificationUsePreprocessor.
|
||||
|
||||
@@ -36,7 +37,7 @@ class UserSatisfactionEstimation(TorchModel):
|
||||
bert_name: The pretrained model, default bert-base-chinese
|
||||
device: The device of running model, default cpu
|
||||
"""
|
||||
super().__init__(model_dir)
|
||||
super().__init__(model_dir, **kwargs)
|
||||
self.model_dir = model_dir
|
||||
self.bert_name = bert_name if bert_name is not None else 'bert-base-chinese'
|
||||
self.device = 'cpu'
|
||||
|
||||
@@ -54,7 +54,8 @@ class Pipeline(ABC):
|
||||
model,
|
||||
device=self.device_name,
|
||||
model_prefetched=True,
|
||||
invoked_by=Invoke.PIPELINE) if is_model(model) else model
|
||||
invoked_by=Invoke.PIPELINE,
|
||||
device_map=self.device_map) if is_model(model) else model
|
||||
else:
|
||||
return model
|
||||
|
||||
@@ -70,6 +71,7 @@ class Pipeline(ABC):
|
||||
preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
|
||||
device: str = 'gpu',
|
||||
auto_collate=True,
|
||||
device_map=None,
|
||||
**kwargs):
|
||||
""" Base class for pipeline.
|
||||
|
||||
@@ -87,6 +89,9 @@ class Pipeline(ABC):
|
||||
compile_options (dict, optional): The compile options if compile=True,
|
||||
default None to use the default params of 'TorchModel.compile'.
|
||||
"""
|
||||
if device_map is not None:
|
||||
assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!'
|
||||
self.device_map = device_map
|
||||
verify_device(device)
|
||||
self.device_name = device
|
||||
|
||||
@@ -133,13 +138,14 @@ class Pipeline(ABC):
|
||||
self._model_prepare_lock.acquire(timeout=600)
|
||||
|
||||
def _prepare_single(model):
|
||||
if isinstance(model, torch.nn.Module):
|
||||
if not isinstance(model, torch.nn.Module) and hasattr(
|
||||
model, 'model'):
|
||||
model = model.model
|
||||
if not isinstance(model, torch.nn.Module):
|
||||
return
|
||||
model.eval()
|
||||
if self.device_map is None:
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
elif hasattr(model, 'model') and isinstance(
|
||||
model.model, torch.nn.Module):
|
||||
model.model.to(self.device)
|
||||
model.model.eval()
|
||||
|
||||
if not self._model_prepare:
|
||||
# prepare model for pytorch
|
||||
|
||||
Reference in New Issue
Block a user