From 8c39eefeff778a6c53c9cc3cdad37da100eda2c5 Mon Sep 17 00:00:00 2001 From: "lee.lcy" Date: Wed, 1 Mar 2023 09:54:24 +0800 Subject: [PATCH] feat(trainer): support load_from for easycv trainer Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11794432 --- modelscope/trainers/easycv/trainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modelscope/trainers/easycv/trainer.py b/modelscope/trainers/easycv/trainer.py index 978bea67..a1ad0649 100644 --- a/modelscope/trainers/easycv/trainer.py +++ b/modelscope/trainers/easycv/trainer.py @@ -4,6 +4,7 @@ from functools import partial from typing import Callable, Optional, Tuple, Union import torch +from easycv.utils.checkpoint import load_checkpoint as ev_load_checkpoint from torch import nn from torch.utils.data import Dataset @@ -103,6 +104,16 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer): % h_i['type']) register_util.register_hook_to_ms(h_i['type'], self.logger) + # load pretrained model + load_from = self.cfg.get('load_from', None) + if load_from is not None: + ev_load_checkpoint( + self.model, + filename=load_from, + map_location=self.device, + strict=False, + ) + # reset parallel if not self._dist: assert not is_parallel(