mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
[to #42322933] add finetune support for cartoon task

Add training support for the person image cartoonization model.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11675597

* add finetune support for cartoon
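Below is a minimal end-to-end sketch of the new finetune-and-export flow, condensed from tests/trainers/test_image_portrait_stylization_trainer.py in this commit; the photo/cartoon folder paths and the step count are illustrative placeholders, not fixed values.

import os

from modelscope.exporters.cv import CartoonTranslationExporter
from modelscope.trainers.cv import CartoonTranslationTrainer

# Illustrative inputs: folders of face photos and cartoon crops.
trainer = CartoonTranslationTrainer(
    model='damo/cv_unet_person-image-cartoon_compound-models',
    work_dir='exp_localtoon',
    photo='data/face_photo',
    cartoon='data/face_cartoon',
    max_steps=10)
trainer.train()

# Freeze the finetuned generator checkpoint into a .pb graph.
ckpt_path = os.path.join('exp_localtoon', 'saved_models', 'model-0')
exporter = CartoonTranslationExporter()
exporter.export_frozen_graph_def(
    ckpt_path=ckpt_path,
    frozen_graph_path=os.path.join(trainer.model_dir, 'cartoon_h.pb'))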
@@ -5,6 +5,7 @@ from .base import Exporter
 from .builder import build_exporter

 if is_tf_available():
+    from .cv import CartoonTranslationExporter
     from .nlp import CsanmtForTranslationExporter
     from .tf_model_exporter import TfModelExporter
 if is_torch_available():
6  modelscope/exporters/cv/__init__.py  Normal file
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.import_utils import is_tf_available

if is_tf_available():
    from .cartoon_translation_exporter import CartoonTranslationExporter
72  modelscope/exporters/cv/cartoon_translation_exporter.py  Normal file
@@ -0,0 +1,72 @@
import os
from typing import Any, Dict

import tensorflow as tf
from packaging import version

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.tf_model_exporter import TfModelExporter
from modelscope.models.cv.cartoon import CartoonModel
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

if version.parse(tf.__version__) < version.parse('2'):
    pass
else:
    logger.info(
        f'TensorFlow version {tf.__version__} found, TF2.x is not supported by CartoonTranslationExporter.'
    )

tf.logging.set_verbosity(tf.logging.INFO)


@EXPORTERS.register_module(module_name=r'cartoon-translation')
class CartoonTranslationExporter(TfModelExporter):

    def __init__(self, model=None):
        super().__init__(model)

    def export_frozen_graph_def(self, ckpt_path: str, frozen_graph_path: str,
                                **kwargs):
        tf.get_variable_scope().reuse_variables()

        # Build the inference graph: HWC input in [0, 255], batched and
        # normalized to [-1, 1].
        input = tf.placeholder(tf.float32, [None, None, 3], name='input_image')
        input = input[:, :, :][tf.newaxis]
        input = input / 127.5 - 1.0

        model = CartoonModel(model_dir='')
        output = model(input)
        final_out = output['output_cartoon'][0]
        final_out = tf.clip_by_value(final_out, -0.999999, 0.999999)
        final_out = (final_out + 1.0) * 127.5
        final_out = tf.cast(final_out, tf.uint8, name='output_image')

        # Only the generator variables are restored and frozen.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.train.Saver(var_list=gene_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init)
            saver.restore(sess, ckpt_path)
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names=['output_image'])
            with open(frozen_graph_path, 'wb') as f:
                f.write(frozen_graph_def.SerializeToString())
            print('freeze done')

        return {'model': frozen_graph_path}

    def export_saved_model(self, output_dir: str, **kwargs):
        raise NotImplementedError(
            'Exporting saved model is not supported by CartoonTranslationExporter currently.'
        )

    def export_onnx(self, output_dir: str, **kwargs):
        raise NotImplementedError(
            'Exporting onnx model is not supported by CartoonTranslationExporter currently.'
        )
@@ -7,14 +7,24 @@ if TYPE_CHECKING:
     from .facelib.facer import FaceAna
     from .mtcnn_pytorch.src.align_trans import (get_reference_facial_points,
                                                 warp_and_crop_face)
-    from .utils import (get_f5p, padTo16x, resize_size)
+    from .utils import (get_f5p, padTo16x, resize_size, all_file,
+                        tf_data_loader, write_batch_image)
+    from .network import disc_sn
+    from .loss import simple_superpixel
+    from .model_tf import CartoonModel

 else:
     _import_structure = {
         'facelib.facer': ['FaceAna'],
         'mtcnn_pytorch.src.align_trans':
         ['get_reference_facial_points', 'warp_and_crop_face'],
-        'utils': ['get_f5p', 'padTo16x', 'resize_size']
+        'utils': [
+            'get_f5p', 'padTo16x', 'resize_size', 'all_file', 'tf_data_loader',
+            'write_batch_image'
+        ],
+        'network': ['disc_sn'],
+        'loss': ['simple_superpixel'],
+        'model_tf': ['CartoonModel'],
     }

     import sys
274  modelscope/models/cv/cartoon/loss.py  Normal file
@@ -0,0 +1,274 @@
'''
CVPR 2020 submission, Paper ID 6791
Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
'''

import os.path as osp

import numpy as np
import scipy.stats as st
import tensorflow as tf
from joblib import Parallel, delayed
from skimage import color, segmentation

from .network import disc_sn

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg19:

    def __init__(self, vgg19_npy_path=None):
        self.data_dict = np.load(
            vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print('Finished loading vgg19.npy')

    def build_conv4_4(self, rgb, include_fc=False):
        # rgb is in [-1, 1]; rescale to [0, 255] and convert to
        # mean-subtracted BGR as expected by the VGG weights.
        rgb_scaled = (rgb + 1) * 127.5

        blue, green, red = tf.split(
            axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(
            axis=3,
            values=[
                blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]
            ])

        self.conv1_1 = self.conv_layer(bgr, 'conv1_1')
        self.relu1_1 = tf.nn.relu(self.conv1_1)
        self.conv1_2 = self.conv_layer(self.relu1_1, 'conv1_2')
        self.relu1_2 = tf.nn.relu(self.conv1_2)
        self.pool1 = self.max_pool(self.relu1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
        self.relu2_1 = tf.nn.relu(self.conv2_1)
        self.conv2_2 = self.conv_layer(self.relu2_1, 'conv2_2')
        self.relu2_2 = tf.nn.relu(self.conv2_2)
        self.pool2 = self.max_pool(self.relu2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
        self.relu3_1 = tf.nn.relu(self.conv3_1)
        self.conv3_2 = self.conv_layer(self.relu3_1, 'conv3_2')
        self.relu3_2 = tf.nn.relu(self.conv3_2)
        self.conv3_3 = self.conv_layer(self.relu3_2, 'conv3_3')
        self.relu3_3 = tf.nn.relu(self.conv3_3)
        self.conv3_4 = self.conv_layer(self.relu3_3, 'conv3_4')
        self.relu3_4 = tf.nn.relu(self.conv3_4)
        self.pool3 = self.max_pool(self.relu3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
        self.relu4_1 = tf.nn.relu(self.conv4_1)
        self.conv4_2 = self.conv_layer(self.relu4_1, 'conv4_2')
        self.relu4_2 = tf.nn.relu(self.conv4_2)
        self.conv4_3 = self.conv_layer(self.relu4_2, 'conv4_3')
        self.relu4_3 = tf.nn.relu(self.conv4_3)
        self.conv4_4 = self.conv_layer(self.relu4_3, 'conv4_4')
        self.relu4_4 = tf.nn.relu(self.conv4_4)
        self.pool4 = self.max_pool(self.relu4_4, 'pool4')

        return self.conv4_4

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(
            bottom,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)
            return bias

    def fc_layer(self, bottom, name):
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])
            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name='filter')

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name='biases')

    def get_fc_weight(self, name):
        return tf.constant(self.data_dict[name][0], name='weights')


def content_loss(model_dir, input_photo, transfer_res, input_superpixel):
    vgg_model = Vgg19(osp.join(model_dir, 'vgg19.npy'))
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(transfer_res)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]
    abs_photo = tf.losses.absolute_difference(vgg_photo, vgg_output)
    photo_loss = tf.reduce_mean(abs_photo) / (h * w * c)
    abs_superpixel = tf.losses.absolute_difference(vgg_superpixel, vgg_output)
    superpixel_loss = tf.reduce_mean(abs_superpixel) / (h * w * c)
    loss = photo_loss + superpixel_loss

    return loss


def style_loss(input_cartoon, output_cartoon):
    # Surface representation: guided filtering removes texture details.
    blur_fake = guided_filter(output_cartoon, output_cartoon, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

    # Texture representation: random grayscale projection.
    gray_fake, gray_cartoon = color_shift(output_cartoon, input_cartoon)

    d_loss_gray, g_loss_gray = lsgan_loss(
        disc_sn,
        gray_cartoon,
        gray_fake,
        scale=1,
        patch=True,
        name='disc_gray')
    d_loss_blur, g_loss_blur = lsgan_loss(
        disc_sn,
        blur_cartoon,
        blur_fake,
        scale=1,
        patch=True,
        name='disc_blur')
    sty_g_loss = g_loss_blur + g_loss_gray
    sty_d_loss = d_loss_blur + d_loss_gray

    return sty_g_loss, sty_d_loss


def gan_loss(discriminator,
             real,
             fake,
             scale=1,
             channel=32,
             patch=False,
             name='discriminator'):

    real_logit = discriminator(
        real, scale, channel, name=name, patch=patch, reuse=False)
    fake_logit = discriminator(
        fake, scale, channel, name=name, patch=patch, reuse=True)

    real_logit = tf.nn.sigmoid(real_logit)
    fake_logit = tf.nn.sigmoid(fake_logit)

    g_loss_blur = -tf.reduce_mean(tf.log(fake_logit))
    d_loss_blur = -tf.reduce_mean(tf.log(real_logit) + tf.log(1. - fake_logit))

    return d_loss_blur, g_loss_blur


def lsgan_loss(discriminator,
               real,
               fake,
               scale=1,
               channel=32,
               patch=False,
               name='discriminator'):

    real_logit = discriminator(
        real, scale, channel, name=name, patch=patch, reuse=False)
    fake_logit = discriminator(
        fake, scale, channel, name=name, patch=patch, reuse=True)

    g_loss = tf.reduce_mean((fake_logit - 1)**2)
    d_loss = 0.5 * (
        tf.reduce_mean((real_logit - 1)**2) + tf.reduce_mean(fake_logit**2))

    return d_loss, g_loss


def total_variation_loss(image, k_size=1):
    h, w = image.get_shape().as_list()[1:3]
    tv_h = tf.reduce_mean(
        (image[:, k_size:, :, :] - image[:, :h - k_size, :, :])**2)
    tv_w = tf.reduce_mean(
        (image[:, :, k_size:, :] - image[:, :, :w - k_size, :])**2)
    tv_loss = (tv_h + tv_w) / (3 * h * w)
    return tv_loss


def guided_filter(x, y, r, eps=1e-2):
    # Edge-preserving smoothing of y guided by x, in box-filter form.
    x_shape = tf.shape(x)
    N = tf_box_filter(
        tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)

    mean_x = tf_box_filter(x, r) / N
    mean_y = tf_box_filter(y, r) / N
    cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
    var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x

    A = cov_xy / (var_x + eps)
    b = mean_y - A * mean_x

    mean_A = tf_box_filter(A, r) / N
    mean_b = tf_box_filter(b, r) / N

    output = mean_A * x + mean_b

    return output


def color_shift(image1, image2, mode='uniform'):
    # Project both images to grayscale with randomly sampled channel weights.
    b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
    b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
    if mode == 'normal':
        b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
        g_weight = tf.random.normal(shape=[1], mean=0.587, stddev=0.1)
        r_weight = tf.random.normal(shape=[1], mean=0.299, stddev=0.1)
    elif mode == 'uniform':
        b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
        g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
        r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
    output1 = (b_weight * b1 + g_weight * g1 + r_weight * r1) / (
        b_weight + g_weight + r_weight)
    output2 = (b_weight * b2 + g_weight * g2 + r_weight * r2) / (
        b_weight + g_weight + r_weight)
    return output1, output2


def simple_superpixel(batch_image, seg_num=200):

    def process_slic(image):
        seg_label = segmentation.slic(
            image,
            n_segments=seg_num,
            sigma=1,
            compactness=10,
            convert2lab=True,
            start_label=1)
        image = color.label2rgb(seg_label, image, kind='avg', bg_label=0)
        return image

    num_job = np.shape(batch_image)[0]
    batch_out = Parallel(n_jobs=num_job)(
        delayed(process_slic)(image) for image in batch_image)
    return np.array(batch_out)


def tf_box_filter(x, r):
    # Mean filter over a (2r+1) x (2r+1) window via depthwise convolution.
    ch = x.get_shape().as_list()[-1]
    weight = 1 / ((2 * r + 1)**2)
    box_kernel = weight * np.ones((2 * r + 1, 2 * r + 1, ch, 1))
    box_kernel = np.array(box_kernel).astype(np.float32)
    output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
    return output


if __name__ == '__main__':
    pass
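For reference, the adversarial objective implemented by lsgan_loss above is the standard least-squares GAN; writing D for the discriminator, x for a real sample and G(z) for a generated one, the code computes:

\[
\mathcal{L}_G = \mathbb{E}\big[(D(G(z)) - 1)^2\big],
\qquad
\mathcal{L}_D = \tfrac{1}{2}\Big(\mathbb{E}\big[(D(x) - 1)^2\big] + \mathbb{E}\big[D(G(z))^2\big]\Big).
\]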
64  modelscope/models/cv/cartoon/model_tf.py  Normal file
@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import tensorflow as tf

from modelscope.models.base import Model, Tensor
from .loss import content_loss, guided_filter, style_loss, total_variation_loss
from .network import unet_generator


class CartoonModel(Model):

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

    def __call__(
            self,
            input_photo: Dict[str, Tensor],
            input_cartoon: Dict[str, Tensor] = None,
            input_superpixel: Dict[str, Tensor] = None) -> Dict[str, Tensor]:
        """Return the result by the model.

        Args:
            input_photo: the preprocessed input photo image
            input_cartoon: the preprocessed input cartoon image
            input_superpixel: the computed input superpixel image

        Returns:
            output_dict: output dict of target ids
        """
        if input_cartoon is None:
            # Inference: generate the cartoon and smooth it with a guided filter.
            output = unet_generator(input_photo)
            output_cartoon = guided_filter(input_photo, output, r=1)
            return {'output_cartoon': output_cartoon}
        else:
            # Training: additionally compute content, style and
            # total-variation losses.
            output = unet_generator(input_photo)
            output_cartoon = guided_filter(input_photo, output, r=1)

            con_loss = content_loss(self.model_dir, input_photo,
                                    output_cartoon, input_superpixel)
            sty_g_loss, sty_d_loss = style_loss(input_cartoon, output_cartoon)
            tv_loss = total_variation_loss(output_cartoon)

            g_loss = 1e-1 * sty_g_loss + 2e2 * con_loss + 1e4 * tv_loss
            d_loss = sty_d_loss

            return {
                'output_cartoon': output_cartoon,
                'g_loss': g_loss,
                'd_loss': d_loss,
            }

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Run the forward pass for a model.

        Args:
            input (Dict[str, Tensor]): the dict of the model inputs for the forward method

        Returns:
            Dict[str, Tensor]: output from the model forward pass
        """
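A minimal sketch of the two call modes of CartoonModel, assuming the TF1.x graph mode used throughout this commit; the placeholder shapes are illustrative, and only one of the two branches should be built per graph, since unet_generator creates its variables with reuse=False.

import tensorflow as tf

from modelscope.models.cv.cartoon import CartoonModel

photo = tf.placeholder(tf.float32, [1, 256, 256, 3])
model = CartoonModel(model_dir='')  # model_dir must hold vgg19.npy for training

build_training_graph = False
if not build_training_graph:
    # Inference mode, as used by CartoonTranslationExporter.
    outputs = model(photo)  # {'output_cartoon': ...}
else:
    # Training mode, as used by CartoonTranslationTrainer.
    cartoon = tf.placeholder(tf.float32, [1, 256, 256, 3])
    superpixel = tf.placeholder(tf.float32, [1, 256, 256, 3])
    outputs = model(photo, cartoon, superpixel)  # adds 'g_loss', 'd_loss'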
146  modelscope/models/cv/cartoon/network.py  Normal file
@@ -0,0 +1,146 @@
import tensorflow as tf
import tensorflow.contrib.slim as slim


def resblock(inputs, out_channel=32, name='resblock'):
    with tf.variable_scope(name):
        x = slim.convolution2d(
            inputs, out_channel, [3, 3], activation_fn=None, scope='conv1')
        x = tf.nn.leaky_relu(x)
        x = slim.convolution2d(
            x, out_channel, [3, 3], activation_fn=None, scope='conv2')

        return x + inputs


def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable(
        'u', [1, w_shape[-1]],
        initializer=tf.random_normal_initializer(),
        trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def conv_spectral_norm(x, channel, k_size, stride=1, name='conv_snorm'):
    with tf.variable_scope(name):
        w = tf.get_variable(
            'kernel', shape=[k_size[0], k_size[1],
                             x.get_shape()[-1], channel])
        b = tf.get_variable(
            'bias', [channel], initializer=tf.constant_initializer(0.0))

        x = tf.nn.conv2d(
            input=x,
            filter=spectral_norm(w),
            strides=[1, stride, stride, 1],
            padding='SAME') + b

        return x


def unet_generator(inputs,
                   channel=32,
                   num_blocks=4,
                   name='generator',
                   reuse=False):
    with tf.variable_scope(name, reuse=reuse):
        # Encoder: two stride-2 downsampling stages.
        x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
        x0 = tf.nn.leaky_relu(x0)

        x1 = slim.convolution2d(
            x0, channel, [3, 3], stride=2, activation_fn=None)
        x1 = tf.nn.leaky_relu(x1)
        x1 = slim.convolution2d(x1, channel * 2, [3, 3], activation_fn=None)
        x1 = tf.nn.leaky_relu(x1)

        x2 = slim.convolution2d(
            x1, channel * 2, [3, 3], stride=2, activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)
        x2 = slim.convolution2d(x2, channel * 4, [3, 3], activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)

        for idx in range(num_blocks):
            x2 = resblock(
                x2, out_channel=channel * 4, name='block_{}'.format(idx))

        x2 = slim.convolution2d(x2, channel * 2, [3, 3], activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)

        # Decoder: bilinear upsampling with skip connections from x1 and x0.
        h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
        x3 = tf.image.resize_bilinear(x2, (h1 * 2, w1 * 2))
        x3 = slim.convolution2d(
            x3 + x1, channel * 2, [3, 3], activation_fn=None)
        x3 = tf.nn.leaky_relu(x3)
        x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
        x3 = tf.nn.leaky_relu(x3)

        h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
        x4 = tf.image.resize_bilinear(x3, (h2 * 2, w2 * 2))
        x4 = slim.convolution2d(x4 + x0, channel, [3, 3], activation_fn=None)
        x4 = tf.nn.leaky_relu(x4)
        x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)

        # x4 = tf.clip_by_value(x4, -1, 1)
        return x4


def disc_sn(x,
            scale=1,
            channel=32,
            patch=True,
            name='discriminator',
            reuse=False):
    with tf.variable_scope(name, reuse=reuse):
        for idx in range(3):
            x = conv_spectral_norm(
                x,
                channel * 2**idx, [3, 3],
                stride=2,
                name='conv{}_1'.format(idx))
            x = tf.nn.leaky_relu(x)

            x = conv_spectral_norm(
                x, channel * 2**idx, [3, 3], name='conv{}_2'.format(idx))
            x = tf.nn.leaky_relu(x)

        if patch is True:
            x = conv_spectral_norm(x, 1, [1, 1], name='conv_out')
        else:
            x = tf.reduce_mean(x, axis=[1, 2])
            x = slim.fully_connected(x, 1, activation_fn=None)

        return x


if __name__ == '__main__':
    pass
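For reference, spectral_norm above approximates the spectral norm (largest singular value) of the reshaped kernel W by one step of power iteration, with u kept as a persistent row vector across steps, and then rescales the weight:

\[
v \leftarrow \frac{u W^{\top}}{\lVert u W^{\top} \rVert_2},
\qquad
u \leftarrow \frac{v W}{\lVert v W \rVert_2},
\qquad
\sigma(W) \approx v W u^{\top},
\qquad
W_{\mathrm{SN}} = W / \sigma(W).
\]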
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import os
+import random

+import cv2
 import numpy as np
 import tensorflow as tf


 def resize_size(image, size=720):
@@ -83,11 +85,87 @@ def find_pupil(landmarks, np_img):
     return xm + xmin, ym + ymin


+def next_batch(filename_list, batch_size, fineSize=256):
+    idx = np.arange(0, len(filename_list))
+    np.random.shuffle(idx)
+    idx = idx[:batch_size]
+    batch_data = []
+    for i in range(batch_size):
+        image = cv2.imread(filename_list[idx[i]])
+        h, w, c = image.shape
+        # Random fineSize x fineSize crop, normalized to [-1, 1].
+        rw = random.randint(0, w - fineSize)
+        rh = random.randint(0, h - fineSize)
+        image = image[rh:rh + fineSize, rw:rw + fineSize, :]
+        image = image.astype(np.float32) / 127.5 - 1
+        batch_data.append(image)
+
+    return np.asarray(batch_data)
+
+
+def read_image(image_path, IMAGE_SIZE=256):
+    image = tf.io.read_file(image_path)
+    image = tf.image.decode_image(image, channels=3)
+    image = tf.image.convert_image_dtype(image, tf.float32)
+    image.set_shape([None, None, 3])
+    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
+    image = image[..., ::-1]
+    # image = image / 127.5 - 1
+    image = (image - 0.5) * 2
+
+    return image
+
+
+def load_data(photo_list):
+    photo = read_image(photo_list)
+    return photo
+
+
+def tf_data_loader(image_list, batch_size):
+    dataset = tf.data.Dataset.from_tensor_slices((image_list))
+    dataset = dataset.shuffle(len(image_list))
+    dataset = dataset.map(load_data, num_parallel_calls=4)
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.prefetch(1)
+    return dataset
+
+
+def write_batch_image(image, save_dir, name, n):
+    # Tile an n*n batch of [-1, 1] images into one grid and save it.
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    fused_dir = os.path.join(save_dir, name)
+    fused_image = [0] * n
+    for i in range(n):
+        fused_image[i] = []
+        for j in range(n):
+            k = i * n + j
+            image[k] = (image[k] + 1) * 127.5
+            image[k] = np.clip(image[k], 0, 255)
+            fused_image[i].append(image[k])
+        fused_image[i] = np.hstack(fused_image[i])
+    fused_image = np.vstack(fused_image)
+    cv2.imwrite(fused_dir, fused_image.astype(np.uint8))
+
+
+def grid_batch_image(image, n):
+    fused_image = [0] * n
+    for i in range(n):
+        fused_image[i] = []
+        for j in range(n):
+            k = i * n + j
+            image[k] = (image[k] + 1) * 127.5
+            image[k] = np.clip(image[k], 0, 255)
+            fused_image[i].append(image[k])
+        fused_image[i] = np.hstack(fused_image[i])
+    fused_image = np.vstack(fused_image)
+    return fused_image
+
+
 def all_file(file_dir):
     L = []
     for root, dirs, files in os.walk(file_dir):
         for file in files:
             extend = os.path.splitext(file)[1]
-            if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
+            if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG':
                 L.append(os.path.join(root, file))
     return L
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from .image_inpainting_trainer import ImageInpaintingTrainer
     from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer
     from .image_defrcn_fewshot_detection_trainer import ImageDefrcnFewshotTrainer
+    from .cartoon_translation_trainer import CartoonTranslationTrainer

 else:
     _import_structure = {
@@ -23,7 +24,8 @@ else:
         'referring_video_object_segmentation_trainer':
         ['ReferringVideoObjectSegmentationTrainer'],
         'image_defrcn_fewshot_detection_trainer':
-        ['ImageDefrcnFewshotTrainer']
+        ['ImageDefrcnFewshotTrainer'],
+        'cartoon_translation_trainer': ['CartoonTranslationTrainer']
     }

     import sys
241  modelscope/trainers/cv/cartoon_translation_trainer.py  Normal file
@@ -0,0 +1,241 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import os.path as osp
from typing import Dict, Optional

import numpy as np
import tensorflow as tf
from packaging import version
from tqdm import tqdm

from modelscope.models.cv.cartoon import (CartoonModel, all_file,
                                          simple_superpixel, tf_data_loader,
                                          write_batch_image)
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger()

if version.parse(tf.__version__) < version.parse('2'):
    pass
else:
    logger.info(
        f'TensorFlow version {tf.__version__} found, TF2.x is not supported by CartoonTranslationTrainer.'
    )


@TRAINERS.register_module(module_name=r'cartoon-translation')
class CartoonTranslationTrainer(BaseTrainer):

    def __init__(self,
                 model: str,
                 cfg_file: str = None,
                 work_dir=None,
                 photo=None,
                 cartoon=None,
                 max_steps=None,
                 *args,
                 **kwargs):
        """
        Args:
            model: the model id of the trained model
            cfg_file: the path of the configuration file
            work_dir: the path to save training results
            photo: the path of photo images for training
            cartoon: the path of cartoon images for training
            max_steps: the total number of training iterations
        Returns:
            initialized trainer: object of CartoonTranslationTrainer
        """
        model = self.get_or_download_model_dir(model)
        tf.reset_default_graph()

        self.model_dir = model
        self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
        if cfg_file is None:
            cfg_file = osp.join(model, ModelFile.CONFIGURATION)

        super().__init__(cfg_file)

        # Values from the configuration file can be overridden by
        # constructor arguments.
        self.params = {}
        self._override_params_from_file()
        if work_dir is not None:
            self.params['work_dir'] = work_dir
        if photo is not None:
            self.params['photo'] = photo
        if cartoon is not None:
            self.params['cartoon'] = cartoon
        if max_steps is not None:
            self.params['max_steps'] = max_steps

        if not os.path.exists(self.params['work_dir']):
            os.makedirs(self.params['work_dir'])

        self.face_photo_list = all_file(self.params['photo'])
        self.face_cartoon_list = all_file(self.params['cartoon'])

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        self._session = tf.Session(config=tf_config)

        self.input_photo = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])
        self.input_superpixel = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])
        self.input_cartoon = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])

        self.model = CartoonModel(self.model_dir)
        output = self.model(self.input_photo, self.input_cartoon,
                            self.input_superpixel)
        self.output_cartoon = output['output_cartoon']
        self.g_loss = output['g_loss']
        self.d_loss = output['d_loss']

        tf.summary.scalar('g_loss', self.g_loss)
        tf.summary.scalar('d_loss', self.d_loss)

        self.train_writer = tf.summary.FileWriter(self.params['work_dir']
                                                  + '/train_log')
        self.summary_op = tf.summary.merge_all()

        # Generator and discriminator variables are optimized separately.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'gene' in var.name]
        disc_vars = [var for var in all_vars if 'disc' in var.name]
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.g_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
                .minimize(self.g_loss, var_list=gene_vars)
            self.d_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
                .minimize(self.d_loss, var_list=disc_vars)

        self.saver = tf.train.Saver(max_to_keep=1000)
        with self._session.as_default() as sess:
            sess.run(tf.global_variables_initializer())
            if self.params['resume_epoch'] != 0:
                logger.info(f'loading model from {self.model_path}')
                self.saver.restore(
                    sess,
                    osp.join(self.model_path,
                             'model-' + str(self.params['resume_epoch'])))

    def _override_params_from_file(self):
        self.params['photo'] = self.cfg['train']['photo']
        self.params['cartoon'] = self.cfg['train']['cartoon']
        self.params['patch_size'] = self.cfg['train']['patch_size']
        self.params['work_dir'] = self.cfg['train']['work_dir']
        self.params['batch_size'] = self.cfg['train']['batch_size']
        self.params['adv_train_lr'] = self.cfg['train']['adv_train_lr']
        self.params['max_steps'] = self.cfg['train']['max_steps']
        self.params['logging_interval'] = self.cfg['train']['logging_interval']
        self.params['ckpt_period_interval'] = self.cfg['train'][
            'ckpt_period_interval']
        self.params['resume_epoch'] = self.cfg['train']['resume_epoch']
        self.params['num_gpus'] = self.cfg['train']['num_gpus']

    def train(self, *args, **kwargs):
        logger.info('Begin local cartoon translator training')

        photo_ds = tf_data_loader(self.face_photo_list,
                                  self.params['batch_size'])
        cartoon_ds = tf_data_loader(self.face_cartoon_list,
                                    self.params['batch_size'])
        photo_iterator = photo_ds.make_initializable_iterator()
        cartoon_iterator = cartoon_ds.make_initializable_iterator()
        photo_next = photo_iterator.get_next()
        cartoon_next = cartoon_iterator.get_next()

        device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu:0'
        with tf.device(device):

            for max_steps in tqdm(range(self.params['max_steps'])):

                self._session.run(photo_iterator.initializer)
                self._session.run(cartoon_iterator.initializer)

                photo_batch, cartoon_batch = self._session.run(
                    [photo_next, cartoon_next])

                # The superpixel input is recomputed from the current
                # generator output (the white-box structure representation).
                transfer_res = self._session.run(
                    self.output_cartoon,
                    feed_dict={self.input_photo: photo_batch})

                input_superpixel = simple_superpixel(transfer_res, seg_num=200)
                g_loss, _ = self._session.run(
                    [self.g_loss, self.g_optim],
                    feed_dict={
                        self.input_photo: photo_batch,
                        self.input_superpixel: input_superpixel,
                        self.input_cartoon: cartoon_batch
                    })

                d_loss, _, train_info = self._session.run(
                    [self.d_loss, self.d_optim, self.summary_op],
                    feed_dict={
                        self.input_photo: photo_batch,
                        self.input_superpixel: input_superpixel,
                        self.input_cartoon: cartoon_batch
                    })

                self.train_writer.add_summary(train_info, max_steps)

                if np.mod(max_steps + 1, self.params['logging_interval']
                          ) == 0 or max_steps == 0:

                    logger.info(
                        f'Iter: {max_steps}, d_loss: {d_loss}, g_loss: {g_loss}'
                    )

                if np.mod(max_steps + 1,
                          self.params['ckpt_period_interval']
                          ) == 0 or max_steps == 0:
                    self.saver.save(
                        self._session,
                        self.params['work_dir'] + '/saved_models/model',
                        write_meta_graph=False,
                        global_step=max_steps)

                    result_face = self._session.run(
                        self.output_cartoon,
                        feed_dict={
                            self.input_photo: photo_batch,
                            self.input_superpixel: photo_batch,
                            self.input_cartoon: cartoon_batch
                        })

                    write_batch_image(
                        result_face, self.params['work_dir'] + '/images',
                        str('%8d' % max_steps) + '_face_result.jpg', 4)
                    write_batch_image(
                        photo_batch, self.params['work_dir'] + '/images',
                        str('%8d' % max_steps) + '_face_photo.jpg', 4)

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate a dataset.

        Evaluate a dataset via a specific model from the `checkpoint_path` path; if `checkpoint_path`
        does not exist, read from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the results of the evaluation
            Example:
            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        raise NotImplementedError(
            'evaluate is not supported by CartoonTranslationTrainer')
@@ -59,6 +59,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
 - test_video_deinterlace.py
 - test_image_inpainting_sdv2.py
 - test_bad_image_detecting.py
+- test_image_portrait_stylization_trainer.py
 - test_controllable_image_generation.py

 envs:
@@ -80,3 +81,4 @@ envs:
 - test_person_image_cartoon.py
 - test_skin_retouching.py
 - test_image_style_transfer.py
+- test_image_portrait_stylization_trainer.py
61  tests/trainers/test_image_portrait_stylization_trainer.py  Normal file
@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

import cv2

from modelscope.exporters.cv import CartoonTranslationExporter
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.trainers.cv import CartoonTranslationTrainer
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TestImagePortraitStylizationTrainer(unittest.TestCase):

    def setUp(self) -> None:
        self.task = Tasks.image_portrait_stylization
        self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        model_id = 'damo/cv_unet_person-image-cartoon_compound-models'

        data_dir = MsDataset.load(
            'dctnet_train_clipart_mini_ms',
            namespace='menyifang',
            split='train').config_kwargs['split_config']['train']

        data_photo = os.path.join(data_dir, 'face_photo')
        data_cartoon = os.path.join(data_dir, 'face_cartoon')
        work_dir = 'exp_localtoon'
        max_steps = 10
        trainer = CartoonTranslationTrainer(
            model=model_id,
            work_dir=work_dir,
            photo=data_photo,
            cartoon=data_cartoon,
            max_steps=max_steps)
        trainer.train()

        # Export the first checkpoint to a frozen graph, then run the
        # pipeline against the updated model directory.
        ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0))
        pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb')
        exporter = CartoonTranslationExporter()
        exporter.export_frozen_graph_def(
            ckpt_path=ckpt_path, frozen_graph_path=pb_path)

        self.pipeline_person_image_cartoon(trainer.model_dir)

    def pipeline_person_image_cartoon(self, model_dir):
        pipeline_cartoon = pipeline(task=self.task, model=model_dir)
        result = pipeline_cartoon(input=self.test_image)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {os.path.abspath("result.png")}')


if __name__ == '__main__':
    unittest.main()