refactor inputs format of model forward

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243

    * refactor inputs format of model forward
This commit is contained in:
jiangnana.jnn
2022-09-08 20:16:14 +08:00
parent be2f31fc15
commit 652ec697b7
4 changed files with 25 additions and 39 deletions

View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict, Union
from typing import Any, Dict, Union
from modelscope.models.base.base_model import Model
from modelscope.utils.config import ConfigDict
@@ -22,25 +22,20 @@ class Head(ABC):
self.config = ConfigDict(kwargs)
@abstractmethod
def forward(self, input: Input) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""
This method will use the output from backbone model to do any
downstream tasks
Args:
input: The tensor output or a model from backbone model
(text generation need a model as input)
Returns: The output from downstream taks
downstream tasks. Receive the output from backbone model.
Returns (Dict[str, Any]): The output from downstream task.
"""
pass
@abstractmethod
def compute_loss(self, outputs: Dict[str, Tensor],
labels) -> Dict[str, Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
"""
compute loss for head during the finetuning
compute loss for head during the finetuning.
Args:
outputs (Dict[str, Tensor]): the output from the model forward
Returns: the loss(Dict[str, Tensor]):
Returns (Dict[str, Any]): The loss dict
"""
pass

View File

@@ -2,7 +2,7 @@
import os
import os.path as osp
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
@@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import device_placement, verify_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -27,35 +25,31 @@ class Model(ABC):
verify_device(device_name)
self._device_name = device_name
def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.postprocess(self.forward(input))
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
return self.postprocess(self.forward(*args, **kwargs))
@abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""
Run the forward pass for a model.
Args:
input (Dict[str, Tensor]): the dict of the model inputs for the forward method
Returns:
Dict[str, Tensor]: output from the model forward pass
Dict[str, Any]: output from the model forward pass
"""
pass
def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
""" Model specific postprocess and convert model output to
standard model outputs.
Args:
input: input data
inputs: input data
Return:
dict of results: a dict containing outputs of model, each
output should have the standard output name.
"""
return input
return inputs
@classmethod
def _instantiate(cls, **kwargs):

View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
from typing import Any, Dict
import torch
@@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module):
super().__init__(**kwargs)
torch.nn.Module.__init__(self)
def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError
def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError

View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict
import torch
from torch import nn
@@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module):
super().__init__(model_dir, *args, **kwargs)
torch.nn.Module.__init__(self)
def __call__(self, input: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
# Adapt to a model whose forward takes only one dict arg; the arg name must be `input` or `inputs`
if func_receive_dict_inputs(self.forward):
return self.postprocess(self.forward(input))
return self.postprocess(self.forward(args[0], **kwargs))
else:
return self.postprocess(self.forward(**input))
return self.postprocess(self.forward(*args, **kwargs))
def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError
def post_init(self):