diff --git a/modelscope/models/base/base_head.py b/modelscope/models/base/base_head.py index 07a68253..11bda32f 100644 --- a/modelscope/models/base/base_head.py +++ b/modelscope/models/base/base_head.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from abc import ABC, abstractmethod -from typing import Dict, Union +from typing import Any, Dict, Union from modelscope.models.base.base_model import Model from modelscope.utils.config import ConfigDict @@ -22,25 +22,20 @@ class Head(ABC): self.config = ConfigDict(kwargs) @abstractmethod - def forward(self, input: Input) -> Dict[str, Tensor]: + def forward(self, *args, **kwargs) -> Dict[str, Any]: """ This method will use the output from backbone model to do any - downstream tasks - Args: - input: The tensor output or a model from backbone model - (text generation need a model as input) - Returns: The output from downstream taks + downstream tasks. Receives the output from the backbone model. + + Returns (Dict[str, Any]): The output from the downstream task. """ pass @abstractmethod - def compute_loss(self, outputs: Dict[str, Tensor], - labels) -> Dict[str, Tensor]: + def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: """ - compute loss for head during the finetuning + compute loss for head during the finetuning. 
- Args: - outputs (Dict[str, Tensor]): the output from the model forward - Returns: the loss(Dict[str, Tensor]): + Returns (Dict[str, Any]): The loss dict """ pass diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 872c42e8..8744ce1c 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -2,7 +2,7 @@ import os import os.path as osp from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model @@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile from modelscope.utils.device import device_placement, verify_device -from modelscope.utils.file_utils import func_receive_dict_inputs -from modelscope.utils.hub import parse_label_mapping from modelscope.utils.logger import get_logger logger = get_logger() @@ -27,35 +25,31 @@ class Model(ABC): verify_device(device_name) self._device_name = device_name - def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - return self.postprocess(self.forward(input)) + def __call__(self, *args, **kwargs) -> Dict[str, Any]: + return self.postprocess(self.forward(*args, **kwargs)) @abstractmethod - def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, *args, **kwargs) -> Dict[str, Any]: """ Run the forward pass for a model. 
- Args: - input (Dict[str, Tensor]): the dict of the model inputs for the forward method - Returns: - Dict[str, Tensor]: output from the model forward pass + Dict[str, Any]: output from the model forward pass """ pass - def postprocess(self, input: Dict[str, Tensor], - **kwargs) -> Dict[str, Tensor]: + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: """ Model specific postprocess and convert model output to standard model outputs. Args: - input: input data + inputs: input data Return: dict of results: a dict containing outputs of model, each output should have the standard output name. """ - return input + return inputs @classmethod def _instantiate(cls, **kwargs): diff --git a/modelscope/models/base/base_torch_head.py b/modelscope/models/base/base_torch_head.py index c5a78519..faee4296 100644 --- a/modelscope/models/base/base_torch_head.py +++ b/modelscope/models/base/base_torch_head.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict +from typing import Any, Dict import torch @@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module): super().__init__(**kwargs) torch.nn.Module.__init__(self) - def forward(self, inputs: Dict[str, - torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, *args, **kwargs) -> Dict[str, Any]: raise NotImplementedError - def compute_loss(self, outputs: Dict[str, torch.Tensor], - labels) -> Dict[str, torch.Tensor]: + def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: raise NotImplementedError diff --git a/modelscope/models/base/base_torch_model.py b/modelscope/models/base/base_torch_model.py index cfc88721..3c99a1f2 100644 --- a/modelscope/models/base/base_torch_model.py +++ b/modelscope/models/base/base_torch_model.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Any, Dict, Optional, Union +from typing import Any, Dict import torch from torch import nn @@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module): super().__init__(model_dir, *args, **kwargs) torch.nn.Module.__init__(self) - def __call__(self, input: Dict[str, - torch.Tensor]) -> Dict[str, torch.Tensor]: + def __call__(self, *args, **kwargs) -> Dict[str, Any]: + # Adapt models whose forward accepts only one dict argument, which must be named input or inputs if func_receive_dict_inputs(self.forward): - return self.postprocess(self.forward(input)) + return self.postprocess(self.forward(args[0], **kwargs)) else: - return self.postprocess(self.forward(**input)) + return self.postprocess(self.forward(*args, **kwargs)) - def forward(self, inputs: Dict[str, - torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, *args, **kwargs) -> Dict[str, Any]: raise NotImplementedError def post_init(self):