Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)
fix HTTP error when downloading the Hugging Face pretrained model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10750369
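The change threads a transformer_cfg_dir argument from ReferringVideoObjectSegmentation through MTTR into MultimodalTransformer, so the RoBERTa encoder and tokenizer are loaded from a directory bundled with the model files instead of being fetched from the Hugging Face Hub at construction time. A minimal sketch of the difference (the local path below is illustrative, not part of the diff):

from transformers import RobertaModel

# Original behaviour: 'roberta-base' is resolved against the Hugging Face Hub,
# which requires network access and can fail with an HTTP error.
text_encoder = RobertaModel.from_pretrained('roberta-base')

# Fixed behaviour: from_pretrained is pointed at a local directory that already
# holds the config and weights, so no download is attempted.
# '/path/to/model_dir' is a hypothetical location, not taken from the diff.
text_encoder = RobertaModel.from_pretrained('/path/to/model_dir/transformer_cfg_dir')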
@@ -31,7 +31,10 @@ class ReferringVideoObjectSegmentation(TorchModel):
         config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
         self.cfg = Config.from_file(config_path)
-        self.model = MTTR(**self.cfg.model)
+        transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')
+
+        self.model = MTTR(
+            transformer_cfg_dir=transformer_cfg_dir, **self.cfg.model)

         model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
         params_dict = torch.load(model_path, map_location='cpu')
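In the hunk above, ReferringVideoObjectSegmentation derives the encoder directory from model_dir and passes it to MTTR. A small sketch of the assumed layout; only the transformer_cfg_dir name comes from the diff, the rest of the listing is a guess at what a bundled Hugging Face checkout would contain:

import os.path as osp

# Assumed layout of the downloaded model directory; only 'transformer_cfg_dir'
# is named in the diff, the other entries are placeholders for the files the
# surrounding code reads (configuration and torch checkpoint).
#
# model_dir/
#     <ModelFile.CONFIGURATION>      read by Config.from_file
#     <ModelFile.TORCH_MODEL_FILE>   loaded with torch.load
#     transformer_cfg_dir/           local RoBERTa config, weights, tokenizer
model_dir = '/path/to/model_dir'  # hypothetical path
transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')
assert osp.isdir(transformer_cfg_dir), 'expected the bundled encoder directory'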
@@ -19,6 +19,7 @@ class MTTR(nn.Module):
                  num_queries,
                  mask_kernels_dim=8,
                  aux_loss=False,
+                 transformer_cfg_dir=None,
                  **kwargs):
         """
         Parameters:
@@ -29,7 +30,9 @@ class MTTR(nn.Module):
         """
         super().__init__()
         self.backbone = init_backbone(**kwargs)
-        self.transformer = MultimodalTransformer(**kwargs)
+        assert transformer_cfg_dir is not None
+        self.transformer = MultimodalTransformer(
+            transformer_cfg_dir=transformer_cfg_dir, **kwargs)
         d_model = self.transformer.d_model
         self.is_referred_head = nn.Linear(
             d_model,
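In the two hunks above, MTTR accepts transformer_cfg_dir explicitly, asserts it was supplied, and forwards it together with the remaining keyword arguments. Reduced to the bare pattern (the class names below are placeholders, not the real modules):

import torch.nn as nn


class InnerTransformer(nn.Module):
    # Stand-in for MultimodalTransformer: it only needs to know the load path.
    def __init__(self, transformer_cfg_dir=None, **kwargs):
        super().__init__()
        self.cfg_dir = transformer_cfg_dir


class OuterModel(nn.Module):
    # Stand-in for MTTR: takes the directory explicitly and passes it through.
    def __init__(self, transformer_cfg_dir=None, **kwargs):
        super().__init__()
        # Fail early if the caller did not supply the bundled directory.
        assert transformer_cfg_dir is not None
        self.transformer = InnerTransformer(
            transformer_cfg_dir=transformer_cfg_dir, **kwargs)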
@@ -26,6 +26,7 @@ class MultimodalTransformer(nn.Module):
                  num_decoder_layers=3,
                  text_encoder_type='roberta-base',
                  freeze_text_encoder=True,
+                 transformer_cfg_dir=None,
                  **kwargs):
         super().__init__()
         self.d_model = kwargs['d_model']
@@ -40,10 +41,12 @@ class MultimodalTransformer(nn.Module):
         self.pos_encoder_2d = PositionEmbeddingSine2D()
         self._reset_parameters()

-        self.text_encoder = RobertaModel.from_pretrained(text_encoder_type)
+        if text_encoder_type != 'roberta-base':
+            transformer_cfg_dir = text_encoder_type
+        self.text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
         self.text_encoder.pooler = None  # this pooler is never used, this is a hack to avoid DDP problems...
         self.tokenizer = RobertaTokenizerFast.from_pretrained(
-            text_encoder_type)
+            transformer_cfg_dir)
         self.freeze_text_encoder = freeze_text_encoder
         if freeze_text_encoder:
             for p in self.text_encoder.parameters():
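In the hunk above, MultimodalTransformer resolves which directory to load from and then uses it for both the encoder and the tokenizer. An illustrative standalone helper that mirrors this logic, assuming a transformers release where from_pretrained accepts a local directory (this helper is not part of the library):

from transformers import RobertaModel, RobertaTokenizerFast


def load_text_encoder(text_encoder_type='roberta-base',
                      transformer_cfg_dir=None,
                      freeze_text_encoder=True):
    # Backwards compatibility: a non-default text_encoder_type is treated as
    # the location to load from; otherwise fall back to the bundled directory.
    if text_encoder_type != 'roberta-base':
        transformer_cfg_dir = text_encoder_type
    # Pointing from_pretrained at a local directory avoids any request to the
    # Hub, which is where the HTTP error in the commit message came from.
    text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
    tokenizer = RobertaTokenizerFast.from_pretrained(transformer_cfg_dir)
    if freeze_text_encoder:
        for p in text_encoder.parameters():
            p.requires_grad_(False)
    return text_encoder, tokenizer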