mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
fix video output of image2video (#488)
* fix video output * fix logger.error * fix log error
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
@@ -348,7 +351,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.info('Working with z of shape {} = {} dimensions.'.format(
|
||||
logger.info('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
@@ -478,7 +481,7 @@ class AutoencoderKL(nn.Module):
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
logging.info(f'Restored from {path}')
|
||||
logger.info(f'Restored from {path}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -72,14 +73,20 @@ class ImageToVideoPipeline(Pipeline):
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
temp_video_file = True
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
h, w, c = video[0].shape
|
||||
video_writer = cv2.VideoWriter(
|
||||
output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
for i in range(len(video)):
|
||||
img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
|
||||
video_writer.release()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
for fid, frame in enumerate(video):
|
||||
tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
|
||||
cv2.imwrite(tpth, frame[:, :, ::-1],
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
||||
|
||||
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
|
||||
-vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'
|
||||
|
||||
status, output = subprocess.getstatusoutput(cmd)
|
||||
if status != 0:
|
||||
logger.error('Save Video Error with {}'.format(output))
|
||||
os.system(f'rm -rf {temp_dir}')
|
||||
|
||||
if temp_video_file:
|
||||
video_file_content = b''
|
||||
with open(output_video_path, 'rb') as f:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -111,9 +112,9 @@ class VideoToVideoPipeline(Pipeline):
|
||||
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
|
||||
-vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'
|
||||
|
||||
status = os.system(cmd)
|
||||
status, output = subprocess.getstatusoutput(cmd)
|
||||
if status != 0:
|
||||
logger.info('Save Video Error with {}'.format(status))
|
||||
logger.error('Save Video Error with {}'.format(output))
|
||||
os.system(f'rm -rf {temp_dir}')
|
||||
|
||||
if temp_video_file:
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user