Files
modelscope/maas_lib/models/base.py

30 lines
920 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
# Framework-agnostic tensor alias. The members are string forward references,
# so neither torch nor tensorflow has to be importable for this module to load.
Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):
    """Abstract base for models: ``forward`` followed by ``post_process``.

    Calling an instance runs ``forward`` on the input dict and then passes
    the result through ``post_process``. Subclasses must implement
    ``forward``; ``post_process`` defaults to the identity.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any positional/keyword arguments."""

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run ``forward`` on *input* and return the post-processed result."""
        raw_outputs = self.forward(input)
        return self.post_process(raw_outputs)

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Compute the model outputs for *input*; subclasses must override."""

    def post_process(self, input: Dict[str, Tensor],
                     **kwargs) -> Dict[str, Tensor]:
        """Model-specific post-processing hook; the default returns *input*.

        Overriding this is optional. Per the original author's note, it is
        intended to be called from Pipeline and (in the future) from the
        evaluation loop. Extra keyword arguments are accepted and ignored here.
        """
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """Load a model from *model_name_or_path*; not provided by this base.

        Always raises ``NotImplementedError`` here. Presumably concrete
        subclasses override it.
        """
        message = 'from_pretrained has not been implemented'
        raise NotImplementedError(message)