feat(trainer): support load_from for easycv trainer

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11794432
This commit is contained in:
lee.lcy
2023-03-01 09:54:24 +08:00
committed by wenmeng.zwm
parent ceeb85f10f
commit 8c39eefeff

View File

@@ -4,6 +4,7 @@ from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
from easycv.utils.checkpoint import load_checkpoint as ev_load_checkpoint
from torch import nn
from torch.utils.data import Dataset
@@ -103,6 +104,16 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer):
% h_i['type'])
register_util.register_hook_to_ms(h_i['type'], self.logger)
# load pretrained model
load_from = self.cfg.get('load_from', None)
if load_from is not None:
ev_load_checkpoint(
self.model,
filename=load_from,
map_location=self.device,
strict=False,
)
# reset parallel
if not self._dist:
assert not is_parallel(