xmem load checkpoints cpu -- li

This commit is contained in:
memoryunreal
2023-05-05 07:00:05 +00:00
parent 65588bb237
commit b1ca9273c8

View File

@@ -195,7 +195,7 @@ class XMemTrainer:
def load_checkpoint(self, path):
# This method loads everything and should be used to resume training
map_location = 'cuda:%d' % self.local_rank
checkpoint = torch.load(path, map_location={'cuda:0': map_location})
checkpoint = torch.load(path, map_location={'cpu': map_location})
it = checkpoint['it']
network = checkpoint['network']
@@ -218,7 +218,7 @@ class XMemTrainer:
def load_network(self, path):
# This method loads only the network weight and should be used to load a pretrained model
map_location = 'cuda:%d' % self.local_rank
src_dict = torch.load(path, map_location={'cuda:0': map_location})
src_dict = torch.load(path, map_location={'cpu': map_location})
self.load_network_in_memory(src_dict)
print(f'Network weight loaded from {path}')