[to #42322933] fix cpu support

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9744999
This commit is contained in:
baiguan.yt
2022-08-13 16:35:47 +08:00
committed by Yingda Chen
parent a76b8de5d9
commit 2e8338fe73
3 changed files with 9 additions and 5 deletions

View File

@@ -45,8 +45,8 @@ class FQA(object):
model.load_state_dict(model_dict)
def get_face_quality(self, img):
img = torch.from_numpy(img).permute(2, 0,
1).unsqueeze(0).flip(1).cuda()
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to(
self.device)
img = (img - 127.5) / 128.0
# extract features & predict quality

View File

@@ -36,7 +36,6 @@ class ImageColorizationPipeline(Pipeline):
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')
self.size = 1024
self.orig_img = None
self.model_type = 'stable'
@@ -91,6 +90,8 @@ class ImageColorizationPipeline(Pipeline):
img = LoadImage.convert_to_img(input).convert('LA').convert('RGB')
self.wide, self.height = img.size
if self.wide * self.height < 100000:
self.size = 256
self.orig_img = img.copy()
img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR)

View File

@@ -58,7 +58,8 @@ class ImagePortraitEnhancementPipeline(Pipeline):
gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
self.face_enhancer.load_state_dict(
torch.load(gpen_model_path), strict=True)
torch.load(gpen_model_path, map_location=torch.device('cpu')),
strict=True)
logger.info('load face enhancer model done')
@@ -82,7 +83,9 @@ class ImagePortraitEnhancementPipeline(Pipeline):
sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth'
self.sr_model.load_state_dict(
torch.load(sr_model_path)['params_ema'], strict=True)
torch.load(sr_model_path,
map_location=torch.device('cpu'))['params_ema'],
strict=True)
logger.info('load sr model done')