image_quality_assessment_mos use LoadImage for image io

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11596634
This commit is contained in:
zhongning.hzn
2023-02-09 03:37:27 +00:00
committed by wenmeng.zwm
parent 9d0559f034
commit e6bbde6ccb

View File

@@ -5,8 +5,6 @@ from typing import Any, Dict, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from modelscope.metainfo import Preprocessors
@@ -14,7 +12,6 @@ from modelscope.preprocessors import LoadImage
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields
from modelscope.utils.hub import read_config
from modelscope.utils.type_assert import type_assert
@@ -29,18 +26,7 @@ class ImageQualityAssessmentMosPreprocessor(Preprocessor):
super().__init__(**kwargs)
def preprocessors(self, input):
if isinstance(input, str):
img = cv2.imread(input)
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
else:
img = input
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
img = LoadImage.convert_to_ndarray(input)
sub_img_dim = (720, 1280)
resize_dim = (1080, 1920)
h, w = img.shape[:2]
@@ -67,7 +53,6 @@ class ImageQualityAssessmentMosPreprocessor(Preprocessor):
if flag:
img = np.rot90(img)
img = img[:, :, ::-1]
img = LoadImage.convert_to_img(img)
test_transforms = transforms.Compose([transforms.ToTensor()])
img = test_transforms(img)