Update network.py -- move hyperparameter to cpu -- li

This commit is contained in:
Zhe Li
2023-06-05 17:47:29 +08:00
committed by GitHub
parent bf9128480c
commit e6e1592737

View File

@@ -142,7 +142,7 @@ class XMem(nn.Module):
if model_path is not None:
# load the model and key/value/hidden dimensions with some hacks
# config is updated with the loaded parameters
model_weights = torch.load(model_path, map_location=map_location)
model_weights = torch.load(model_path, map_location="cpu")
self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights