diff --git a/tracker/model/trainer.py b/tracker/model/trainer.py index 05b4e19..817ed5f 100644 --- a/tracker/model/trainer.py +++ b/tracker/model/trainer.py @@ -195,7 +195,7 @@ class XMemTrainer: def load_checkpoint(self, path): # This method loads everything and should be used to resume training map_location = 'cuda:%d' % self.local_rank - checkpoint = torch.load(path, map_location={'cuda:0': map_location}) + checkpoint = torch.load(path, map_location={'cpu': map_location}) it = checkpoint['it'] network = checkpoint['network'] @@ -218,7 +218,7 @@ class XMemTrainer: def load_network(self, path): # This method loads only the network weight and should be used to load a pretrained model map_location = 'cuda:%d' % self.local_rank - src_dict = torch.load(path, map_location={'cuda:0': map_location}) + src_dict = torch.load(path, map_location={'cpu': map_location}) self.load_network_in_memory(src_dict) print(f'Network weight loaded from {path}')