mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
Update network.py -- move hyperparameter to cpu -- li
This commit is contained in:
@@ -142,7 +142,7 @@ class XMem(nn.Module):
|
|||||||
if model_path is not None:
|
if model_path is not None:
|
||||||
# load the model and key/value/hidden dimensions with some hacks
|
# load the model and key/value/hidden dimensions with some hacks
|
||||||
# config is updated with the loaded parameters
|
# config is updated with the loaded parameters
|
||||||
model_weights = torch.load(model_path, map_location=map_location)
|
model_weights = torch.load(model_path, map_location="cpu")
|
||||||
self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
|
self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
|
||||||
self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
|
self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
|
||||||
self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
|
self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
|
||||||
|
|||||||
Reference in New Issue
Block a user