mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #42322933] fix video inpainting cpu inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10186204
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
""" VideoInpaintingProcess
|
||||
Base modules are adapted from https://github.com/researchmm/STTN,
|
||||
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
|
||||
The implementation here is modified based on STTN,
|
||||
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -243,7 +243,8 @@ def inpainting_by_model_balance(model, video_inputPath, mask_path,
|
||||
for m in masks_temp
|
||||
]
|
||||
masks_temp = _to_tensors(masks_temp).unsqueeze(0)
|
||||
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
|
||||
if torch.cuda.is_available():
|
||||
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
|
||||
comp_frames = [None] * video_length
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
""" VideoInpaintingNetwork
|
||||
Base modules are adapted from https://github.com/researchmm/STTN,
|
||||
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
|
||||
""" VideoInpaintingProcess
|
||||
The implementation here is modified based on STTN,
|
||||
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
@@ -84,8 +87,13 @@ class VideoInpainting(TorchModel):
|
||||
super().__init__(
|
||||
model_dir=model_dir, device_id=device_id, *args, **kwargs)
|
||||
self.model = InpaintGenerator()
|
||||
pretrained_params = torch.load('{}/{}'.format(
|
||||
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
else:
|
||||
device = 'cpu'
|
||||
pretrained_params = torch.load(
|
||||
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
|
||||
map_location=device)
|
||||
self.model.load_state_dict(pretrained_params['netG'])
|
||||
self.model.eval()
|
||||
self.device_id = device_id
|
||||
|
||||
Reference in New Issue
Block a user