mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-14 15:37:50 +01:00
Update network.py -- move hyperparameter to cpu -- li
This commit is contained in:
@@ -142,7 +142,7 @@ class XMem(nn.Module):
|
||||
if model_path is not None:
|
||||
# load the model and key/value/hidden dimensions with some hacks
|
||||
# config is updated with the loaded parameters
|
||||
model_weights = torch.load(model_path, map_location=map_location)
|
||||
model_weights = torch.load(model_path, map_location="cpu")
|
||||
self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
|
||||
self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
|
||||
self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
|
||||
|
||||
Reference in New Issue
Block a user