mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #42322933] fix cpu support
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9744999
This commit is contained in:
@@ -45,8 +45,8 @@ class FQA(object):
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
def get_face_quality(self, img):
|
||||
img = torch.from_numpy(img).permute(2, 0,
|
||||
1).unsqueeze(0).flip(1).cuda()
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to(
|
||||
self.device)
|
||||
img = (img - 127.5) / 128.0
|
||||
|
||||
# extract features & predict quality
|
||||
|
||||
@@ -36,7 +36,6 @@ class ImageColorizationPipeline(Pipeline):
|
||||
self.device = torch.device('cuda')
|
||||
else:
|
||||
self.device = torch.device('cpu')
|
||||
self.size = 1024
|
||||
|
||||
self.orig_img = None
|
||||
self.model_type = 'stable'
|
||||
@@ -91,6 +90,8 @@ class ImageColorizationPipeline(Pipeline):
|
||||
img = LoadImage.convert_to_img(input).convert('LA').convert('RGB')
|
||||
|
||||
self.wide, self.height = img.size
|
||||
if self.wide * self.height < 100000:
|
||||
self.size = 256
|
||||
self.orig_img = img.copy()
|
||||
img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR)
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ class ImagePortraitEnhancementPipeline(Pipeline):
|
||||
|
||||
gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
|
||||
self.face_enhancer.load_state_dict(
|
||||
torch.load(gpen_model_path), strict=True)
|
||||
torch.load(gpen_model_path, map_location=torch.device('cpu')),
|
||||
strict=True)
|
||||
|
||||
logger.info('load face enhancer model done')
|
||||
|
||||
@@ -82,7 +83,9 @@ class ImagePortraitEnhancementPipeline(Pipeline):
|
||||
|
||||
sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth'
|
||||
self.sr_model.load_state_dict(
|
||||
torch.load(sr_model_path)['params_ema'], strict=True)
|
||||
torch.load(sr_model_path,
|
||||
map_location=torch.device('cpu'))['params_ema'],
|
||||
strict=True)
|
||||
|
||||
logger.info('load sr model done')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user