Fix HTTP error when downloading Hugging Face pretrained model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10750369
Author: shuying.shu
Date: 2022-11-16 20:30:03 +08:00
Committed by: yingda.chen
Parent: 3798677395
Commit: 10926a06d4
3 changed files with 13 additions and 4 deletions
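
In context: MultimodalTransformer used to call RobertaModel.from_pretrained('roberta-base'), which downloads the text encoder from the Hugging Face Hub at runtime and raises an HTTP error when the Hub is unreachable. This commit threads a local directory, transformer_cfg_dir, from the model directory down to the text encoder so it is loaded offline. A minimal sketch of the idea, assuming that directory holds a standard RoBERTa snapshot (config.json, vocab.json, merges.txt, pytorch_model.bin); the path below is a placeholder:

    import os.path as osp

    from transformers import RobertaModel, RobertaTokenizerFast

    model_dir = '/path/to/model_dir'  # placeholder for the downloaded model directory
    transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')

    # from_pretrained accepts a local directory, so no network request is made
    text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
    tokenizer = RobertaTokenizerFast.from_pretrained(transformer_cfg_dir)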

@@ -31,7 +31,10 @@ class ReferringVideoObjectSegmentation(TorchModel):
         config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
         self.cfg = Config.from_file(config_path)
-        self.model = MTTR(**self.cfg.model)
+        transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')
+        self.model = MTTR(
+            transformer_cfg_dir=transformer_cfg_dir, **self.cfg.model)
         model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
         params_dict = torch.load(model_path, map_location='cpu')
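
For this to work, the transformer_cfg_dir folder has to ship with the model files. A one-time packaging sketch (not part of this commit) that snapshots roberta-base into that directory while network access is still available:

    from transformers import RobertaModel, RobertaTokenizerFast

    transformer_cfg_dir = '/path/to/model_dir/transformer_cfg_dir'  # placeholder

    # Download once, then save config, weights and tokenizer files locally so
    # later from_pretrained(transformer_cfg_dir) calls never hit huggingface.co.
    RobertaModel.from_pretrained('roberta-base').save_pretrained(transformer_cfg_dir)
    RobertaTokenizerFast.from_pretrained('roberta-base').save_pretrained(transformer_cfg_dir)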

@@ -19,6 +19,7 @@ class MTTR(nn.Module):
                  num_queries,
                  mask_kernels_dim=8,
                  aux_loss=False,
+                 transformer_cfg_dir=None,
                  **kwargs):
         """
         Parameters:
@@ -29,7 +30,9 @@ class MTTR(nn.Module):
         """
         super().__init__()
         self.backbone = init_backbone(**kwargs)
-        self.transformer = MultimodalTransformer(**kwargs)
+        assert transformer_cfg_dir is not None
+        self.transformer = MultimodalTransformer(
+            transformer_cfg_dir=transformer_cfg_dir, **kwargs)
         d_model = self.transformer.d_model
         self.is_referred_head = nn.Linear(
             d_model,

@@ -26,6 +26,7 @@ class MultimodalTransformer(nn.Module):
                  num_decoder_layers=3,
                  text_encoder_type='roberta-base',
                  freeze_text_encoder=True,
+                 transformer_cfg_dir=None,
                  **kwargs):
         super().__init__()
         self.d_model = kwargs['d_model']
@@ -40,10 +41,12 @@ class MultimodalTransformer(nn.Module):
         self.pos_encoder_2d = PositionEmbeddingSine2D()
         self._reset_parameters()
-        self.text_encoder = RobertaModel.from_pretrained(text_encoder_type)
+        if text_encoder_type != 'roberta-base':
+            transformer_cfg_dir = text_encoder_type
+        self.text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
         self.text_encoder.pooler = None  # this pooler is never used, this is a hack to avoid DDP problems...
         self.tokenizer = RobertaTokenizerFast.from_pretrained(
-            text_encoder_type)
+            transformer_cfg_dir)
         self.freeze_text_encoder = freeze_text_encoder
         if freeze_text_encoder:
             for p in self.text_encoder.parameters():
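
The last hunk also keeps a fallback: a non-default text_encoder_type passed by the caller overrides the bundled directory, otherwise the local transformer_cfg_dir is used. The same resolution restated as a standalone helper (a sketch only; resolve_text_encoder_source is not part of this commit):

    def resolve_text_encoder_source(text_encoder_type, transformer_cfg_dir):
        # An explicit, non-default encoder id or path wins; otherwise load from
        # the local directory bundled with the model to avoid any Hub request.
        if text_encoder_type != 'roberta-base':
            return text_encoder_type
        return transformer_cfg_dir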