mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
xmem load checkpoints cpu -- li
This commit is contained in:
@@ -195,7 +195,7 @@ class XMemTrainer:
|
|||||||
def load_checkpoint(self, path):
|
def load_checkpoint(self, path):
|
||||||
# This method loads everything and should be used to resume training
|
# This method loads everything and should be used to resume training
|
||||||
map_location = 'cuda:%d' % self.local_rank
|
map_location = 'cuda:%d' % self.local_rank
|
||||||
checkpoint = torch.load(path, map_location={'cuda:0': map_location})
|
checkpoint = torch.load(path, map_location={'cpu': map_location})
|
||||||
|
|
||||||
it = checkpoint['it']
|
it = checkpoint['it']
|
||||||
network = checkpoint['network']
|
network = checkpoint['network']
|
||||||
@@ -218,7 +218,7 @@ class XMemTrainer:
|
|||||||
def load_network(self, path):
|
def load_network(self, path):
|
||||||
# This method loads only the network weight and should be used to load a pretrained model
|
# This method loads only the network weight and should be used to load a pretrained model
|
||||||
map_location = 'cuda:%d' % self.local_rank
|
map_location = 'cuda:%d' % self.local_rank
|
||||||
src_dict = torch.load(path, map_location={'cuda:0': map_location})
|
src_dict = torch.load(path, map_location={'cpu': map_location})
|
||||||
|
|
||||||
self.load_network_in_memory(src_dict)
|
self.load_network_in_memory(src_dict)
|
||||||
print(f'Network weight loaded from {path}')
|
print(f'Network weight loaded from {path}')
|
||||||
|
|||||||
Reference in New Issue
Block a user