From b1ca9273c80475e9d784466e470c77da7f8698dd Mon Sep 17 00:00:00 2001 From: memoryunreal <814514103@qq.com> Date: Fri, 5 May 2023 07:00:05 +0000 Subject: [PATCH] xmem load checkpoints cpu -- li --- tracker/model/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tracker/model/trainer.py b/tracker/model/trainer.py index 05b4e19..817ed5f 100644 --- a/tracker/model/trainer.py +++ b/tracker/model/trainer.py @@ -195,7 +195,7 @@ class XMemTrainer: def load_checkpoint(self, path): # This method loads everything and should be used to resume training map_location = 'cuda:%d' % self.local_rank - checkpoint = torch.load(path, map_location={'cuda:0': map_location}) + checkpoint = torch.load(path, map_location={'cpu': map_location}) it = checkpoint['it'] network = checkpoint['network'] @@ -218,7 +218,7 @@ class XMemTrainer: def load_network(self, path): # This method loads only the network weight and should be used to load a pretrained model map_location = 'cuda:%d' % self.local_rank - src_dict = torch.load(path, map_location={'cuda:0': map_location}) + src_dict = torch.load(path, map_location={'cpu': map_location}) self.load_network_in_memory(src_dict) print(f'Network weight loaded from {path}')