mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
xmem load checkpoints cpu -- li
This commit is contained in:
@@ -195,7 +195,7 @@ class XMemTrainer:
|
||||
def load_checkpoint(self, path):
|
||||
# This method loads everything and should be used to resume training
|
||||
map_location = 'cuda:%d' % self.local_rank
|
||||
checkpoint = torch.load(path, map_location={'cuda:0': map_location})
|
||||
checkpoint = torch.load(path, map_location={'cpu': map_location})
|
||||
|
||||
it = checkpoint['it']
|
||||
network = checkpoint['network']
|
||||
@@ -218,7 +218,7 @@ class XMemTrainer:
|
||||
def load_network(self, path):
|
||||
# This method loads only the network weight and should be used to load a pretrained model
|
||||
map_location = 'cuda:%d' % self.local_rank
|
||||
src_dict = torch.load(path, map_location={'cuda:0': map_location})
|
||||
src_dict = torch.load(path, map_location={'cpu': map_location})
|
||||
|
||||
self.load_network_in_memory(src_dict)
|
||||
print(f'Network weight loaded from {path}')
|
||||
|
||||
Reference in New Issue
Block a user