mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
feat(trainer): support load_from for easycv trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11794432
This commit is contained in:
@@ -4,6 +4,7 @@ from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from easycv.utils.checkpoint import load_checkpoint as ev_load_checkpoint
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -103,6 +104,16 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer):
|
||||
% h_i['type'])
|
||||
register_util.register_hook_to_ms(h_i['type'], self.logger)
|
||||
|
||||
# load pretrained model
|
||||
load_from = self.cfg.get('load_from', None)
|
||||
if load_from is not None:
|
||||
ev_load_checkpoint(
|
||||
self.model,
|
||||
filename=load_from,
|
||||
map_location=self.device,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
# reset parallel
|
||||
if not self._dist:
|
||||
assert not is_parallel(
|
||||
|
||||
Reference in New Issue
Block a user