diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py
index 91f7ea91..29e702be 100644
--- a/modelscope/models/cv/referring_video_object_segmentation/model.py
+++ b/modelscope/models/cv/referring_video_object_segmentation/model.py
@@ -31,7 +31,10 @@ class ReferringVideoObjectSegmentation(TorchModel):
         config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
         self.cfg = Config.from_file(config_path)
 
-        self.model = MTTR(**self.cfg.model)
+        transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')
+
+        self.model = MTTR(
+            transformer_cfg_dir=transformer_cfg_dir, **self.cfg.model)
 
         model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
         params_dict = torch.load(model_path, map_location='cpu')
diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py
index e603df6c..48d4bf70 100644
--- a/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py
+++ b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py
@@ -19,6 +19,7 @@ class MTTR(nn.Module):
                  num_queries,
                  mask_kernels_dim=8,
                  aux_loss=False,
+                 transformer_cfg_dir=None,
                  **kwargs):
         """
         Parameters:
@@ -29,7 +30,9 @@
         """
         super().__init__()
         self.backbone = init_backbone(**kwargs)
-        self.transformer = MultimodalTransformer(**kwargs)
+        assert transformer_cfg_dir is not None
+        self.transformer = MultimodalTransformer(
+            transformer_cfg_dir=transformer_cfg_dir, **kwargs)
         d_model = self.transformer.d_model
         self.is_referred_head = nn.Linear(
             d_model,
diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py
index 39962715..f750437a 100644
--- a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py
+++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py
@@ -26,6 +26,7 @@ class MultimodalTransformer(nn.Module):
                  num_decoder_layers=3,
                  text_encoder_type='roberta-base',
                  freeze_text_encoder=True,
+                 transformer_cfg_dir=None,
                  **kwargs):
         super().__init__()
         self.d_model = kwargs['d_model']
@@ -40,10 +41,12 @@
         self.pos_encoder_2d = PositionEmbeddingSine2D()
         self._reset_parameters()
 
-        self.text_encoder = RobertaModel.from_pretrained(text_encoder_type)
+        if text_encoder_type != 'roberta-base':
+            transformer_cfg_dir = text_encoder_type
+        self.text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
         self.text_encoder.pooler = None  # this pooler is never used, this is a hack to avoid DDP problems...
         self.tokenizer = RobertaTokenizerFast.from_pretrained(
-            text_encoder_type)
+            transformer_cfg_dir)
         self.freeze_text_encoder = freeze_text_encoder
         if freeze_text_encoder:
             for p in self.text_encoder.parameters():
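
A minimal usage sketch of what this diff enables (not part of the diff itself; the directory name 'transformer_cfg_dir' comes from model.py above, while model_dir here is a hypothetical placeholder): with a local RoBERTa config/tokenizer/weights directory bundled next to the model files, both from_pretrained calls resolve locally instead of downloading 'roberta-base' from the Hugging Face hub.

import os.path as osp

from transformers import RobertaModel, RobertaTokenizerFast

# Hypothetical local model directory; the bundled RoBERTa files are expected
# under '<model_dir>/transformer_cfg_dir', the path joined in model.py above.
model_dir = './referring_vos_model_dir'
transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir')

# With transformer_cfg_dir threaded through MTTR -> MultimodalTransformer,
# the text encoder and tokenizer are loaded from the local directory rather
# than fetched by the default 'roberta-base' identifier.
text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
tokenizer = RobertaTokenizerFast.from_pretrained(transformer_cfg_dir)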