[to #42322933]fix UT error for 830 version

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10142442
This commit is contained in:
wenmeng.zwm
2022-09-16 22:42:39 +08:00
parent 9503e3903f
commit d5c5c64dc8
17 changed files with 51 additions and 66 deletions

View File

@@ -18,21 +18,8 @@
from __future__ import absolute_import, division, print_function
import copy
import logging
import math
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import Union
import json
import numpy as np
import torch
import torch_scatter
from icecream import ic
from torch import nn
from torch.nn import CrossEntropyLoss
logger = logging.getLogger(__name__)

View File

@@ -17,21 +17,15 @@
from __future__ import absolute_import, division, print_function
import copy
import logging
import math
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import Union
import json
import numpy as np
import torch
import torch_scatter
from torch import nn
from torch.nn import CrossEntropyLoss
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.utils.constant import ModelFile
@@ -121,33 +115,17 @@ class BertEmbeddings(nn.Module):
words_embeddings = self.word_embeddings(input_ids)
header_embeddings = self.word_embeddings(header_ids)
# header mean pooling
header_flatten_embeddings = self.word_embeddings(header_flatten_tokens)
header_flatten_index = header_flatten_index.reshape(
(-1, header_flatten_index.shape[1], 1))
header_flatten_index = header_flatten_index.repeat(
1, 1, header_flatten_embeddings.shape[2])
header_flatten_output = header_flatten_output.reshape(
(-1, header_flatten_output.shape[1], 1))
header_flatten_output = header_flatten_output.repeat(
1, 1, header_flatten_embeddings.shape[2])
header_embeddings = torch_scatter.scatter_mean(
header_flatten_embeddings,
header_flatten_index,
out=header_flatten_output,
dim=1)
token_column_id = token_column_id.reshape(
(-1, token_column_id.shape[1], 1))
token_column_id = token_column_id.repeat(
(1, 1, header_embeddings.shape[2]))
token_column_mask = token_column_mask.reshape(
(-1, token_column_mask.shape[1], 1))
token_column_mask = token_column_mask.repeat(
(1, 1, header_embeddings.shape[2]))
token_header_embeddings = torch.gather(header_embeddings, 1,
token_column_id)
words_embeddings = words_embeddings * (1.0 - token_column_mask) + \
token_header_embeddings * token_column_mask
if col_dict_list is not None and l_hs is not None:
col_dict_list = np.array(col_dict_list)[ids.cpu().numpy()].tolist()
header_len = np.array(
header_len, dtype=object)[ids.cpu().numpy()].tolist()
for bi, col_dict in enumerate(col_dict_list):
for ki, vi in col_dict.items():
length = header_len[bi][vi]
if length == 0:
continue
words_embeddings[bi, ki, :] = torch.mean(
header_embeddings[bi, vi, :length, :], dim=0)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

View File

@@ -1,11 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict, Optional
from typing import Dict
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer
@@ -15,7 +14,6 @@ from modelscope.models.builder import MODELS
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model
from modelscope.preprocessors.star3.fields.struct import Constant
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import verify_device

View File

@@ -48,7 +48,7 @@ class SequenceClassificationModel(SingleBackboneTaskModelBase):
self.build_backbone(backbone_cfg)
self.build_head(head_cfg)
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(pooled_output)

View File

@@ -101,7 +101,7 @@ class FillMaskPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs, **forward_params)
return self.model(**inputs, **forward_params)
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

View File

@@ -97,7 +97,7 @@ class FillMaskPonetPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs, **forward_params)
return self.model(**inputs, **forward_params)
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

View File

@@ -35,7 +35,7 @@ class SequenceClassificationPipelineBase(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs, **forward_params)
return self.model(**inputs, **forward_params)
def postprocess(self,
inputs: Dict[str, Any],

View File

@@ -2,7 +2,6 @@
import os
from typing import Any, Dict, Union
import torch
from transformers import BertTokenizer
from modelscope.metainfo import Pipelines
@@ -88,7 +87,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
return current_sql
elif action == 'del_focus':
pre_final_sql = copy.deepcopy(history_sql)
pre_final_sql = history_sql
pre_sels = []
pre_aggs = []
for idx, seli in enumerate(pre_final_sql['sel']):
@@ -151,7 +150,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
return pre_final_sql
elif action == 'del_cond':
pre_final_sql = copy.deepcopy(history_sql)
pre_final_sql = history_sql
final_conds = []

View File

@@ -85,7 +85,7 @@ class ZeroShotClassificationPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs, **forward_params)
return self.model(**inputs, **forward_params)
def postprocess(self,
inputs: Dict[str, Any],

View File

@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .utils import AddLrLogHook, EasyCVMetric
else:
_import_structure = {'utils': ['AddLrLogHook', 'EasyCVMetric']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
import numpy as np
from modelscope.metainfo import Trainers
from modelscope.models.nlp.space.model.generator import Generator
from modelscope.models.nlp.space.model.generator import SpaceGenerator
from modelscope.models.nlp.space.model.model_base import SpaceModelBase
from modelscope.preprocessors.space.data_loader import \
get_sequential_data_loader
@@ -90,7 +90,7 @@ class DialogIntentTrainer(BaseTrainer):
data_type='test')
# set generator
generator = Generator.create(self.cfg, reader=bpe)
generator = SpaceGenerator.create(self.cfg, reader=bpe)
# construct model
self.model = SpaceModelBase.create(
self.cfg.Model.init_checkpoint,

View File

@@ -542,7 +542,7 @@ class EpochBasedTrainer(BaseTrainer):
value = train_outputs.get(key, None)
if value is not None:
if dist.is_available() and dist.is_initialized():
value = value.data.clone()
value = value.data.clone().to('cuda')
dist.all_reduce(value.div_(dist.get_world_size()))
log_vars.update({key: value.item()})
self.log_buffer.update(log_vars)

View File

@@ -293,6 +293,9 @@ class AstScaning(object):
if type(attribute_node).__name__ == 'Str':
result.append((getattr(node,
'arg'), attribute_node.s, None))
elif type(attribute_node).__name__ == 'Constant':
result.append(
(getattr(node, 'arg'), attribute_node.value, None))
else:
result.append((getattr(node, 'arg'), )
+ _get_attribute_item(attribute_node))

View File

@@ -1,4 +1,3 @@
import os.path as osp
from typing import List
from modelscope.outputs import OutputKeys

View File

@@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from typing import List
from transformers import BertTokenizer

View File

@@ -6,6 +6,9 @@ isolated: # test cases that may require excessive anmount of GPU memory, which
- test_video_summarization.py
- test_dialog_modeling.py
- test_csanmt_translation.py
- test_image_super_resolution.py
- test_easycv_trainer.py
- test_segformer.py
envs:
default: # default env, case not in other env will in default, pytorch.

View File

@@ -31,11 +31,11 @@ class EasyCVTrainerTestSegformer(unittest.TestCase):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def _train(self):
# adapt to distributed mode
from easycv.utils.test_util import pseudo_dist_init
pseudo_dist_init()
cfg_options = {'train.max_epochs': 2}
cfg_options = {
'train.max_epochs': 2,
'model.decode_head.norm_cfg.type': 'BN'
}
trainer_name = Trainers.easycv
train_dataset = MsDataset.load(