refactor inputs format of model forward
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from abc import ABC, abstractmethod
-from typing import Dict, Union
+from typing import Any, Dict, Union

 from modelscope.models.base.base_model import Model
 from modelscope.utils.config import ConfigDict
@@ -22,25 +22,20 @@ class Head(ABC):
         self.config = ConfigDict(kwargs)

     @abstractmethod
-    def forward(self, input: Input) -> Dict[str, Tensor]:
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:
         """
         This method will use the output from backbone model to do any
-        downstream tasks
-        Args:
-            input: The tensor output or a model from backbone model
-                (text generation need a model as input)
-        Returns: The output from downstream taks
+        downstream tasks. Receive the output from the backbone model.
+
+        Returns (Dict[str, Any]): The output from the downstream task.
         """
         pass

     @abstractmethod
-    def compute_loss(self, outputs: Dict[str, Tensor],
-                     labels) -> Dict[str, Tensor]:
+    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
         """
-        compute loss for head during the finetuning
-        Args:
-            outputs (Dict[str, Tensor]): the output from the model forward
-        Returns: the loss(Dict[str, Tensor]):
+        Compute loss for the head during finetuning.
+
+        Returns (Dict[str, Any]): The loss dict.
         """
         pass
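With this change, Head subclasses implement forward(*args, **kwargs) and compute_loss(*args, **kwargs) and return plain dicts. The standalone sketch below is not part of this commit and deliberately does not inherit the modelscope Head class; the class, argument, and key names are hypothetical, only the signatures follow the new contract.

from typing import Any, Dict

import torch


class HypotheticalClassificationHead:
    """Illustrative head following the new forward(*args, **kwargs) contract."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 2):
        self.linear = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        # The backbone output may arrive positionally or as a keyword argument.
        sequence_output = args[0] if args else kwargs['sequence_output']
        logits = self.linear(sequence_output[:, 0, :])  # pool the first token
        return {'logits': logits}

    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
        outputs, labels = kwargs['outputs'], kwargs['labels']
        loss = torch.nn.functional.cross_entropy(outputs['logits'], labels)
        return {'loss': loss}


head = HypotheticalClassificationHead()
out = head.forward(torch.randn(2, 16, 768))
loss = head.compute_loss(outputs=out, labels=torch.tensor([0, 1]))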
@@ -2,7 +2,7 @@
 import os
 import os.path as osp
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
@@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained
 from modelscope.utils.config import Config
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
 from modelscope.utils.device import device_placement, verify_device
-from modelscope.utils.file_utils import func_receive_dict_inputs
-from modelscope.utils.hub import parse_label_mapping
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -27,35 +25,31 @@ class Model(ABC):
         verify_device(device_name)
         self._device_name = device_name

-    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        return self.postprocess(self.forward(input))
+    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
+        return self.postprocess(self.forward(*args, **kwargs))

     @abstractmethod
-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:
         """
         Run the forward pass for a model.

-        Args:
-            input (Dict[str, Tensor]): the dict of the model inputs for the forward method
-
         Returns:
-            Dict[str, Tensor]: output from the model forward pass
+            Dict[str, Any]: output from the model forward pass
         """
         pass

-    def postprocess(self, input: Dict[str, Tensor],
-                    **kwargs) -> Dict[str, Tensor]:
+    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """ Model specific postprocess and convert model output to
         standard model outputs.

         Args:
-            input: input data
+            inputs: input data

         Return:
             dict of results: a dict containing outputs of model, each
             output should have the standard output name.
         """
-        return input
+        return inputs

     @classmethod
     def _instantiate(cls, **kwargs):
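The effect on Model.__call__ is that inputs are no longer forced into a single dict: whatever the caller passes goes straight to forward, and postprocess receives the dict that forward returns. A minimal standalone sketch follows; the ToyModel class and its argument names are invented for illustration and it does not derive from the modelscope Model.

from typing import Any, Dict


class ToyModel:
    """Illustrative model mirroring the new pass-through __call__ behaviour."""

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        return self.postprocess(self.forward(*args, **kwargs))

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        # Accepts named inputs such as input_ids=..., attention_mask=...
        # instead of one fixed Dict[str, Tensor] parameter.
        return {'raw': kwargs if kwargs else list(args)}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return inputs


model = ToyModel()
outputs = model(input_ids=[1, 2, 3], attention_mask=[1, 1, 1])
assert set(outputs['raw']) == {'input_ids', 'attention_mask'}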
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Dict
+from typing import Any, Dict

 import torch

@@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module):
         super().__init__(**kwargs)
         torch.nn.Module.__init__(self)

-    def forward(self, inputs: Dict[str,
-                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError

-    def compute_loss(self, outputs: Dict[str, torch.Tensor],
-                     labels) -> Dict[str, torch.Tensor]:
+    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict

 import torch
 from torch import nn
@@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module):
         super().__init__(model_dir, *args, **kwargs)
         torch.nn.Module.__init__(self)

-    def __call__(self, input: Dict[str,
-                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
+        # Adapt a model whose forward takes a single dict arg; the arg name must be input or inputs
         if func_receive_dict_inputs(self.forward):
-            return self.postprocess(self.forward(input))
+            return self.postprocess(self.forward(args[0], **kwargs))
         else:
-            return self.postprocess(self.forward(**input))
+            return self.postprocess(self.forward(*args, **kwargs))

-    def forward(self, inputs: Dict[str,
-                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def forward(self, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError

     def post_init(self):
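TorchModel.__call__ keeps backward compatibility for models whose forward still takes a single dict argument, using func_receive_dict_inputs from modelscope.utils.file_utils. That helper's implementation is not shown in this diff; the sketch below uses a hypothetical stand-in, guess_receives_dict_inputs, which assumes the check is a signature inspection for an argument named input or inputs, purely to illustrate how the two call styles are dispatched.

import inspect
from typing import Any, Dict


def guess_receives_dict_inputs(func) -> bool:
    # Hypothetical stand-in for func_receive_dict_inputs: True when the only
    # data argument of forward() is named 'input' or 'inputs'.
    params = [p for p in inspect.signature(func).parameters if p != 'self']
    return params[:1] in (['input'], ['inputs'])


class LegacyDictModel:
    """Old-style model whose forward takes one dict argument."""

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {'echo': inputs}

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        # Mirrors the dispatch in the new TorchModel.__call__: a single-dict
        # forward receives the dict positionally, everything else is unpacked.
        if guess_receives_dict_inputs(self.forward):
            return self.forward(args[0], **kwargs)
        return self.forward(*args, **kwargs)


print(LegacyDictModel()({'input_ids': [1, 2, 3]}))
# -> {'echo': {'input_ids': [1, 2, 3]}}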