[to #42322933] fix video inpainting cpu inference

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10186204
This commit is contained in:
tingwei.gtw
2022-09-20 18:39:43 +08:00
committed by Yingda Chen
parent baed83b27d
commit 02f0f37134
3 changed files with 18 additions and 9 deletions

View File

@@ -1,4 +1,4 @@
# copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

View File

@@ -1,6 +1,6 @@
""" VideoInpaintingProcess
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
"""
import os
@@ -243,7 +243,8 @@ def inpainting_by_model_balance(model, video_inputPath, mask_path,
for m in masks_temp
]
masks_temp = _to_tensors(masks_temp).unsqueeze(0)
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
if torch.cuda.is_available():
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
comp_frames = [None] * video_length
model.eval()
with torch.no_grad():

View File

@@ -1,15 +1,18 @@
""" VideoInpaintingNetwork
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
""" VideoInpaintingProcess
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from modelscope.metainfo import Models
from modelscope.models import Model
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
@@ -84,8 +87,13 @@ class VideoInpainting(TorchModel):
super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)
self.model = InpaintGenerator()
pretrained_params = torch.load('{}/{}'.format(
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
pretrained_params = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=device)
self.model.load_state_dict(pretrained_params['netG'])
self.model.eval()
self.device_id = device_id