mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-23 19:49:24 +01:00
30 lines
920 B
Python
30 lines
920 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Tuple, Union
|
|
|
|
# Framework-agnostic tensor type used in the annotations of this module.
# The members are written as quoted forward references, so creating the alias
# itself does not require importing either framework here.
# NOTE(review): presumably 'tf' refers to TensorFlow -- confirm.
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
|
|
|
|
|
class Model(ABC):
    """Abstract base class for models invoked as ``model(input)``.

    Calling an instance runs :meth:`forward` on the input and passes the
    result through :meth:`post_process`. Subclasses must implement
    :meth:`forward`; the remaining methods have default implementations.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary arguments and ignore them.

        The base class stores no state; the permissive signature lets
        subclasses forward whatever constructor arguments they receive.
        """

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run :meth:`forward` and then :meth:`post_process` on its result.

        Args:
            input: Mapping from names to tensors, handed to :meth:`forward`
                unchanged.

        Returns:
            Whatever :meth:`post_process` returns for the forward output.
        """
        return self.post_process(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Compute the model outputs for ``input``.

        Must be overridden: a subclass that does not implement this method
        cannot be instantiated. The base implementation does nothing and
        returns ``None``.

        Args:
            input: Mapping from names to tensors.

        Returns:
            Mapping from names to tensors.
        """

    def post_process(self, input: Dict[str, Tensor],
                     **kwargs) -> Dict[str, Tensor]:
        """Post-process the output of :meth:`forward`.

        The default implementation returns ``input`` unchanged and ignores
        ``kwargs``.

        Args:
            input: Mapping from names to tensors (the forward output when
                invoked through ``__call__``).
            **kwargs: Extra options for subclass implementations; unused
                here.

        Returns:
            The same ``input`` object.
        """
        # model specific postprocess, implementation is optional
        # will be called in Pipeline and evaluation loop(in the future)
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """Alternate constructor hook; not implemented on the base class.

        Args:
            model_name_or_path: Presumably a model identifier or a local
                path to load from -- the base class never reads it.
            *model_args: Unused by the base class.
            **kwargs: Unused by the base class.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this
                method.
        """
        raise NotImplementedError('from_pretrained has not been implemented')