mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
bad_image_detecting模型支持在数据集上validation功能
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11761935
This commit is contained in:
committed by
wenmeng.zwm
parent
058f772d16
commit
8411645524
@@ -70,6 +70,7 @@ task_default_metrics = {
|
||||
[Metrics.image_quality_assessment_degradation_metric],
|
||||
Tasks.image_quality_assessment_mos:
|
||||
[Metrics.image_quality_assessment_mos_metric],
|
||||
Tasks.bad_image_detecting: [Metrics.accuracy],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -52,15 +53,14 @@ class BadImageDetecting(TorchModel):
|
||||
|
||||
return {'output': ret}
|
||||
|
||||
def _evaluate_postprocess(self, input: Tensor,
|
||||
target: Tensor) -> Dict[str, list]:
|
||||
def _evaluate_postprocess(self, input: Tensor) -> Dict[str, list]:
|
||||
torch.cuda.empty_cache()
|
||||
with torch.no_grad():
|
||||
preds = self.model(input)
|
||||
_, pred_ = torch.max(preds, dim=1)
|
||||
del input
|
||||
torch.cuda.empty_cache()
|
||||
return {'pred': pred_, 'target': target}
|
||||
return {OutputKeys.LABEL: pred_}
|
||||
|
||||
def forward(self, inputs: Dict[str,
|
||||
Tensor]) -> Dict[str, Union[list, Tensor]]:
|
||||
@@ -74,7 +74,8 @@ class BadImageDetecting(TorchModel):
|
||||
"""
|
||||
if self.training:
|
||||
return self._train_forward(**inputs)
|
||||
elif 'target' in inputs:
|
||||
return self._evaluate_postprocess(**inputs)
|
||||
elif OutputKeys.LABEL in inputs:
|
||||
infeat = inputs['input']
|
||||
return self._evaluate_postprocess(infeat)
|
||||
else:
|
||||
return self._inference_forward(**inputs)
|
||||
|
||||
@@ -7,8 +7,9 @@ from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.preprocessors.cv.bad_image_preprocessor import \
|
||||
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
||||
BadImageDetectingPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
@@ -31,9 +32,11 @@ class BadImageDetectingDataset(TorchTaskDataset):
|
||||
|
||||
# Load input video paths.
|
||||
item_dict = self.dataset[index]
|
||||
iterm_label = item_dict['label']
|
||||
|
||||
img = LoadImage.convert_to_ndarray(input)
|
||||
iterm_label = item_dict['category']
|
||||
# print(input)
|
||||
img = LoadImage.convert_to_ndarray(item_dict['image:FILE'])
|
||||
img = self.preprocessor(img)
|
||||
|
||||
return {'input': img['input'], 'target': iterm_label}
|
||||
return {
|
||||
'input': img['input'].squeeze(0),
|
||||
OutputKeys.LABEL: iterm_label
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user