add model name to baseModel. use model name as tag

This commit is contained in:
jiangyu.xzy
2022-11-01 15:31:08 +08:00
parent 60af6b701b
commit cc76d900bc
4 changed files with 25 additions and 6 deletions

16
modelscope/hub/t_jy.py Normal file
View File

@@ -0,0 +1,16 @@
def dec(param1):
print(param1)
def in_dec(func):
def in_func(name):
return func(name)
return in_func
return in_dec
@dec("dec1")
def aa(param):
print(param)
return
aa("heell")

View File

@@ -131,6 +131,8 @@ class Model(ABC):
if not hasattr(model, 'cfg'):
model.cfg = cfg
model.name = model_name_or_path
return model
def save_pretrained(self,

View File

@@ -152,8 +152,9 @@ class Pipeline(ABC):
**kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function
model_name = self.cfg.model.type
create_library_statistics("pipeline", model_name, None)
for single_model in self.models:
if hasattr(single_model, 'name'):
create_library_statistics("pipeline", single_model.name, None)
# place model to cpu or gpu
if (self.model or (self.has_multiple_models and self.models[0])):
if not self._model_prepare:

View File

@@ -437,8 +437,8 @@ class EpochBasedTrainer(BaseTrainer):
def train(self, checkpoint_path=None, *args, **kwargs):
self._mode = ModeKeys.TRAIN
model_name = self.cfg.model.type
create_library_statistics("train", model_name, None)
if hasattr(self.model, 'name'):
create_library_statistics("train", self.model.name, None)
if self.train_dataset is None:
self.train_dataloader = self.get_train_dataloader()
@@ -459,8 +459,8 @@ class EpochBasedTrainer(BaseTrainer):
self.train_loop(self.train_dataloader)
def evaluate(self, checkpoint_path=None):
model_name = self.cfg.model.type
create_library_statistics("evaluate", model_name, None)
if hasattr(self.model, 'name'):
create_library_statistics("evaluate", self.model.name, None)
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
from modelscope.trainers.hooks import CheckpointHook
CheckpointHook.load_checkpoint(checkpoint_path, self)