[to #42322933] add finetune support for cartoon task

人像卡通化模型增加训练支持

 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11675597

* add finetune support for cartoon
This commit is contained in:
myf272609
2023-02-28 17:01:34 +08:00
committed by wenmeng.zwm
parent 61b1ee024a
commit e63593f3bb
12 changed files with 961 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ from .base import Exporter
from .builder import build_exporter
if is_tf_available():
from .cv import CartoonTranslationExporter
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
if is_torch_available():

View File

@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available
if is_tf_available():
from .cartoon_translation_exporter import CartoonTranslationExporter

View File

@@ -0,0 +1,72 @@
import os
from typing import Any, Dict
import tensorflow as tf
from packaging import version
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.tf_model_exporter import TfModelExporter
from modelscope.models.cv.cartoon import CartoonModel
from modelscope.utils.logger import get_logger
logger = get_logger(__name__)

# This exporter is written against the TF1 graph/session APIs.
# Bugfix: the log message referenced an undefined name `_tf_version`
# (NameError under TF2), and `tf.logging` does not exist in TF2, so the
# unconditional set_verbosity call would raise AttributeError there.
_tf_version = tf.__version__
if version.parse(_tf_version) < version.parse('2'):
    tf.logging.set_verbosity(tf.logging.INFO)
else:
    logger.info(
        f'TensorFlow version {_tf_version} found, TF2.x is not supported by CartoonTranslationExporter.'
    )
@EXPORTERS.register_module(module_name=r'cartoon-translation')
class CartoonTranslationExporter(TfModelExporter):
    """Export a trained cartoon-translation (TF1) checkpoint to a frozen graph."""

    def __init__(self, model=None):
        super().__init__(model)

    def export_frozen_graph_def(self, ckpt_path: str, frozen_graph_path: str,
                                **kwargs):
        """Freeze the generator sub-graph of a checkpoint into a GraphDef file.

        Args:
            ckpt_path: path of the TF1 training checkpoint to restore.
            frozen_graph_path: destination file for the serialized GraphDef.

        Returns:
            Dict with key 'model' pointing at the written frozen-graph path.
        """
        tf.get_variable_scope().reuse_variables()
        # Inference graph: single HWC image in [0, 255] -> batch in [-1, 1].
        # (renamed from `input`, which shadowed the builtin)
        input_tensor = tf.placeholder(
            tf.float32, [None, None, 3], name='input_image')
        input_tensor = input_tensor[:, :, :][tf.newaxis]
        input_tensor = input_tensor / 127.5 - 1.0
        model = CartoonModel(model_dir='')
        output = model(input_tensor)
        final_out = output['output_cartoon'][0]
        # Map the generator output back to uint8 range and name the node so
        # it can be addressed when freezing.
        final_out = tf.clip_by_value(final_out, -0.999999, 0.999999)
        final_out = (final_out + 1.0) * 127.5
        final_out = tf.cast(final_out, tf.uint8, name='output_image')
        # Only the generator weights are needed for inference.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.train.Saver(var_list=gene_vars)
        init = tf.global_variables_initializer()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init)
            saver.restore(sess, ckpt_path)
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names=['output_image'])
            with open(frozen_graph_path, 'wb') as f:
                f.write(frozen_graph_def.SerializeToString())
        # Use the module logger instead of a bare print (consistency fix).
        logger.info('freeze done')
        return {'model': frozen_graph_path}

    def export_saved_model(self, output_dir: str, **kwargs):
        raise NotImplementedError(
            'Exporting saved model is not supported by CartoonTranslationExporter currently.'
        )

    def export_onnx(self, output_dir: str, **kwargs):
        raise NotImplementedError(
            'Exporting onnx model is not supported by CartoonTranslationExporter currently.'
        )

View File

@@ -7,14 +7,24 @@ if TYPE_CHECKING:
from .facelib.facer import FaceAna
from .mtcnn_pytorch.src.align_trans import (get_reference_facial_points,
warp_and_crop_face)
from .utils import (get_f5p, padTo16x, resize_size)
from .utils import (get_f5p, padTo16x, resize_size, all_file,
tf_data_loader, write_batch_image)
from .network import disc_sn
from .loss import simple_superpixel
from .model_tf import CartoonModel
else:
_import_structure = {
'facelib.facer': ['FaceAna'],
'mtcnn_pytorch.src.align_trans':
['get_reference_facial_points', 'warp_and_crop_face'],
'utils': ['get_f5p', 'padTo16x', 'resize_size']
'utils': [
'get_f5p', 'padTo16x', 'resize_size', 'all_file', 'tf_data_loader',
'write_batch_image'
],
'network': ['disc_sn'],
'loss': ['simple_superpixel'],
'model_tf': ['CartoonModel'],
}
import sys

View File

@@ -0,0 +1,274 @@
'''
CVPR 2020 submission, Paper ID 6791
Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
'''
import os.path as osp
import numpy as np
import scipy.stats as st
import tensorflow as tf
from joblib import Parallel, delayed
from skimage import color, segmentation
from .network import disc_sn
VGG_MEAN = [103.939, 116.779, 123.68]
class Vgg19:
    """Fixed (non-trainable) VGG19 feature extractor loaded from vgg19.npy.

    All weights are embedded as tf.constant, so the network contributes no
    trainable variables; it is used only to compute perceptual features.
    """

    def __init__(self, vgg19_npy_path=None):
        # The npy file maps layer names to [weights, biases] pairs saved
        # with latin1 pickling (Python-2 era dump).
        self.data_dict = np.load(
            vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print('Finished loading vgg19.npy')

    def build_conv4_4(self, rgb, include_fc=False):
        """Build VGG19 up to conv4_4 and return that (pre-ReLU) feature map.

        Args:
            rgb: NHWC tensor with values in [-1, 1], channels in RGB order.
            include_fc: unused here; kept for interface compatibility.
        """
        # Undo the [-1, 1] scaling, convert to BGR and subtract the
        # ImageNet channel means VGG was trained with.
        rgb_scaled = (rgb + 1) * 127.5
        blue, green, red = tf.split(
            axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(
            axis=3,
            values=[
                blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2]
            ])
        # Block 1
        self.conv1_1 = self.conv_layer(bgr, 'conv1_1')
        self.relu1_1 = tf.nn.relu(self.conv1_1)
        self.conv1_2 = self.conv_layer(self.relu1_1, 'conv1_2')
        self.relu1_2 = tf.nn.relu(self.conv1_2)
        self.pool1 = self.max_pool(self.relu1_2, 'pool1')
        # Block 2
        self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
        self.relu2_1 = tf.nn.relu(self.conv2_1)
        self.conv2_2 = self.conv_layer(self.relu2_1, 'conv2_2')
        self.relu2_2 = tf.nn.relu(self.conv2_2)
        self.pool2 = self.max_pool(self.relu2_2, 'pool2')
        # Block 3
        self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
        self.relu3_1 = tf.nn.relu(self.conv3_1)
        self.conv3_2 = self.conv_layer(self.relu3_1, 'conv3_2')
        self.relu3_2 = tf.nn.relu(self.conv3_2)
        self.conv3_3 = self.conv_layer(self.relu3_2, 'conv3_3')
        self.relu3_3 = tf.nn.relu(self.conv3_3)
        self.conv3_4 = self.conv_layer(self.relu3_3, 'conv3_4')
        self.relu3_4 = tf.nn.relu(self.conv3_4)
        self.pool3 = self.max_pool(self.relu3_4, 'pool3')
        # Block 4 — the returned feature map is conv4_4 before its ReLU.
        self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
        self.relu4_1 = tf.nn.relu(self.conv4_1)
        self.conv4_2 = self.conv_layer(self.relu4_1, 'conv4_2')
        self.relu4_2 = tf.nn.relu(self.conv4_2)
        self.conv4_3 = self.conv_layer(self.relu4_2, 'conv4_3')
        self.relu4_3 = tf.nn.relu(self.conv4_3)
        self.conv4_4 = self.conv_layer(self.relu4_3, 'conv4_4')
        self.relu4_4 = tf.nn.relu(self.conv4_4)
        self.pool4 = self.max_pool(self.relu4_4, 'pool4')
        return self.conv4_4

    def max_pool(self, bottom, name):
        # 2x2 max pooling with stride 2, SAME padding.
        return tf.nn.max_pool(
            bottom,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name=name)

    def conv_layer(self, bottom, name):
        # 3x3 conv with constant (frozen) VGG weights plus bias.
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)
            return bias

    def fc_layer(self, bottom, name):
        # Fully-connected layer; unused by build_conv4_4 but kept for
        # completeness of the original VGG definition.
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])
            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
            return fc

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name='filter')

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name='biases')

    def get_fc_weight(self, name):
        return tf.constant(self.data_dict[name][0], name='weights')
def content_loss(model_dir, input_photo, transfer_res, input_superpixel):
    """VGG19 conv4_4 perceptual loss tying the stylized result to both the
    input photo and its superpixel-smoothed version."""
    vgg = Vgg19(osp.join(model_dir, 'vgg19.npy'))
    feat_photo = vgg.build_conv4_4(input_photo)
    feat_result = vgg.build_conv4_4(transfer_res)
    feat_superpixel = vgg.build_conv4_4(input_superpixel)
    h, w, c = feat_photo.get_shape().as_list()[1:]
    area = h * w * c
    photo_term = tf.reduce_mean(
        tf.losses.absolute_difference(feat_photo, feat_result)) / area
    superpixel_term = tf.reduce_mean(
        tf.losses.absolute_difference(feat_superpixel, feat_result)) / area
    return photo_term + superpixel_term
def style_loss(input_cartoon, output_cartoon):
    """Surface (gray) and texture (blur) LSGAN losses against real cartoons.

    Returns (generator_loss, discriminator_loss).
    """
    fake_blur = guided_filter(output_cartoon, output_cartoon, r=5, eps=2e-1)
    real_blur = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)
    fake_gray, real_gray = color_shift(output_cartoon, input_cartoon)
    # Grayscale branch first, then blur branch — this fixes the order in
    # which the two discriminators' variables are created.
    d_loss_gray, g_loss_gray = lsgan_loss(
        disc_sn, real_gray, fake_gray, scale=1, patch=True, name='disc_gray')
    d_loss_blur, g_loss_blur = lsgan_loss(
        disc_sn, real_blur, fake_blur, scale=1, patch=True, name='disc_blur')
    return g_loss_blur + g_loss_gray, d_loss_blur + d_loss_gray
def gan_loss(discriminator,
             real,
             fake,
             scale=1,
             channel=32,
             patch=False,
             name='discriminator'):
    """Classic minimax GAN loss with sigmoid-activated discriminator logits."""
    real_prob = tf.nn.sigmoid(
        discriminator(
            real, scale, channel, name=name, patch=patch, reuse=False))
    fake_prob = tf.nn.sigmoid(
        discriminator(
            fake, scale, channel, name=name, patch=patch, reuse=True))
    g_loss = -tf.reduce_mean(tf.log(fake_prob))
    d_loss = -tf.reduce_mean(tf.log(real_prob) + tf.log(1. - fake_prob))
    return d_loss, g_loss
def lsgan_loss(discriminator,
               real,
               fake,
               scale=1,
               channel=32,
               patch=False,
               name='discriminator'):
    """Least-squares GAN loss (Mao et al.): real logits pushed to 1,
    fake logits to 0 for D; fake logits to 1 for G."""
    real_logit = discriminator(
        real, scale, channel, name=name, patch=patch, reuse=False)
    fake_logit = discriminator(
        fake, scale, channel, name=name, patch=patch, reuse=True)
    d_loss = 0.5 * (tf.reduce_mean(tf.square(real_logit - 1))
                    + tf.reduce_mean(tf.square(fake_logit)))
    g_loss = tf.reduce_mean(tf.square(fake_logit - 1))
    return d_loss, g_loss
def total_variation_loss(image, k_size=1):
    """Anisotropic total-variation regularizer, normalized by image area."""
    h, w = image.get_shape().as_list()[1:3]
    diff_h = image[:, k_size:, :, :] - image[:, :h - k_size, :, :]
    diff_w = image[:, :, k_size:, :] - image[:, :, :w - k_size, :]
    return (tf.reduce_mean(diff_h**2) + tf.reduce_mean(diff_w**2)) / (
        3 * h * w)
def guided_filter(x, y, r, eps=1e-2):
    """Edge-preserving guided filter built from depthwise box filters.

    Args:
        x: guidance tensor (NHWC).
        y: tensor to be filtered, same shape as x.
        r: box-filter radius.
        eps: regularization term; larger values smooth more.

    Returns:
        Filtered tensor with the same shape as the inputs.
    """
    x_shape = tf.shape(x)
    # Normalizer: box-filtered ones (window size under SAME padding).
    N = tf_box_filter(
        tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
    mean_x = tf_box_filter(x, r) / N
    mean_y = tf_box_filter(y, r) / N
    cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
    var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
    # Per-window linear model y ~= A * x + b.
    A = cov_xy / (var_x + eps)
    b = mean_y - A * mean_x
    # Average the per-window coefficients before applying them.
    mean_A = tf_box_filter(A, r) / N
    mean_b = tf_box_filter(b, r) / N
    output = mean_A * x + mean_b
    return output
def color_shift(image1, image2, mode='uniform'):
    """Collapse two BGR images to grayscale with shared random channel weights.

    The same sampled weights are applied to both images so that the pair
    remains comparable for the discriminator.

    Args:
        image1, image2: NHWC tensors with channels in BGR order.
        mode: 'normal' samples weights around the BT.601 luma coefficients;
            'uniform' samples from a +/-0.1 band around them (default).

    Returns:
        Tuple of the two weighted, weight-normalized single-channel images.
    """
    b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
    b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
    if mode == 'normal':
        # Bugfix: g_weight/r_weight used np.random.normal(shape=...),
        # which is not a valid numpy signature (numpy takes `size=`) and
        # raised TypeError; tf.random.normal also keeps the sampling
        # inside the graph like the blue channel already did.
        b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
        g_weight = tf.random.normal(shape=[1], mean=0.587, stddev=0.1)
        r_weight = tf.random.normal(shape=[1], mean=0.299, stddev=0.1)
    elif mode == 'uniform':
        b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
        g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
        r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
    output1 = (b_weight * b1 + g_weight * g1 + r_weight * r1) / (
        b_weight + g_weight + r_weight)
    output2 = (b_weight * b2 + g_weight * g2 + r_weight * r2) / (
        b_weight + g_weight + r_weight)
    return output1, output2
def simple_superpixel(batch_image, seg_num=200):
    """Segment each image into SLIC superpixels and average colors per segment.

    Args:
        batch_image: numpy batch of images, shape (N, H, W, C).
        seg_num: approximate number of superpixels per image.

    Returns:
        numpy array of flattened-color images, same shape as the input batch.
    """

    def process_slic(image):
        # SLIC segmentation followed by per-label mean-color replacement.
        seg_label = segmentation.slic(
            image,
            n_segments=seg_num,
            sigma=1,
            compactness=10,
            convert2lab=True,
            start_label=1)
        image = color.label2rgb(seg_label, image, kind='avg', bg_label=0)
        return image

    # One joblib worker per image in the batch.
    num_job = np.shape(batch_image)[0]
    batch_out = Parallel(n_jobs=num_job)(
        delayed(process_slic)(image) for image in batch_image)
    return np.array(batch_out)
def tf_box_filter(x, r):
    """Depthwise box (mean) filter with radius `r` and SAME padding."""
    channels = x.get_shape().as_list()[-1]
    side = 2 * r + 1
    # Uniform averaging kernel, one filter per input channel.
    kernel = np.full((side, side, channels, 1), 1 / side**2).astype(
        np.float32)
    return tf.nn.depthwise_conv2d(x, kernel, [1, 1, 1, 1], 'SAME')
if __name__ == '__main__':
pass

View File

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import tensorflow as tf
from modelscope.models.base import Model, Tensor
from .loss import content_loss, guided_filter, style_loss, total_variation_loss
from .network import unet_generator
class CartoonModel(Model):
    """White-box cartoonization model (TF1 graph style).

    Called with only a photo it builds the inference branch; when cartoon
    and superpixel inputs are also supplied it additionally builds the
    generator/discriminator training losses.
    """

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        # Directory holding auxiliary weights (vgg19.npy is read from here
        # by content_loss during training).
        self.model_dir = model_dir

    def __call__(
            self,
            input_photo: Dict[str, Tensor],
            input_cartoon: Dict[str, Tensor] = None,
            input_superpixel: Dict[str, Tensor] = None) -> Dict[str, Tensor]:
        """return the result by the model

        Args:
            input_photo: the preprocessed input photo image
            input_cartoon: the preprocessed input cartoon image; when None,
                only the inference branch is built
            input_superpixel: the computed input superpixel image

        Returns:
            output_dict: 'output_cartoon' always; 'g_loss' and 'd_loss'
                additionally when training inputs are given
        """
        if input_cartoon is None:
            # Inference: generate, then smooth with a guided filter using
            # the photo as guidance.
            output = unet_generator(input_photo)
            output_cartoon = guided_filter(input_photo, output, r=1)
            return {'output_cartoon': output_cartoon}
        else:
            output = unet_generator(input_photo)
            output_cartoon = guided_filter(input_photo, output, r=1)
            con_loss = content_loss(self.model_dir, input_photo,
                                    output_cartoon, input_superpixel)
            sty_g_loss, sty_d_loss = style_loss(input_cartoon, output_cartoon)
            tv_loss = total_variation_loss(output_cartoon)
            # Fixed loss weights: style 1e-1, content 2e2, TV 1e4.
            g_loss = 1e-1 * sty_g_loss + 2e2 * con_loss + 1e4 * tv_loss
            d_loss = sty_d_loss
            return {
                'output_cartoon': output_cartoon,
                'g_loss': g_loss,
                'd_loss': d_loss,
            }

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Run the forward pass for a model.

        Args:
            input (Dict[str, Tensor]): the dict of the model inputs for the forward method

        Returns:
            Dict[str, Tensor]: output from the model forward pass
        """
        # NOTE(review): intentionally empty — returns None. The graph is
        # built through __call__ instead; confirm forward is never invoked.

View File

@@ -0,0 +1,146 @@
import tensorflow as tf
import tensorflow.contrib.slim as slim
def resblock(inputs, out_channel=32, name='resblock'):
    """Residual block: two 3x3 convs with a leaky ReLU in between, plus
    an identity skip connection."""
    with tf.variable_scope(name):
        hidden = slim.convolution2d(
            inputs, out_channel, [3, 3], activation_fn=None, scope='conv1')
        hidden = tf.nn.leaky_relu(hidden)
        hidden = slim.convolution2d(
            hidden, out_channel, [3, 3], activation_fn=None, scope='conv2')
        return inputs + hidden
def spectral_norm(w, iteration=1):
    """Spectrally normalize weight `w` via power iteration.

    Divides w by an estimate of its largest singular value. The persistent
    variable `u` holds the running singular-vector estimate and is updated
    as a side effect (through the assign in the control dependency), not
    by gradient descent.
    """
    w_shape = w.shape.as_list()
    # Flatten to a 2-D matrix: (prod(other dims), out_channels).
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable(
        'u', [1, w_shape[-1]],
        initializer=tf.random_normal_initializer(),
        trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    # Stop gradients so the estimate is treated as a constant w.r.t. w.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    # sigma approximates the spectral norm (largest singular value) of w.
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    # Persist the updated estimate before using the normalized weight.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm
def conv_spectral_norm(x, channel, k_size, stride=1, name='conv_snorm'):
    """2-D convolution whose kernel is spectrally normalized."""
    with tf.variable_scope(name):
        kernel = tf.get_variable(
            'kernel',
            shape=[k_size[0], k_size[1],
                   x.get_shape()[-1], channel])
        bias = tf.get_variable(
            'bias', [channel], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(
            input=x,
            filter=spectral_norm(kernel),
            strides=[1, stride, stride, 1],
            padding='SAME')
        return conv + bias
def unet_generator(inputs,
                   channel=32,
                   num_blocks=4,
                   name='generator',
                   reuse=False):
    """U-Net style generator: two stride-2 downsamples, residual blocks at
    the bottleneck, then two bilinear upsamples with skip connections.

    Args:
        inputs: NHWC image tensor in [-1, 1].
        channel: base channel width; deeper stages use 2x and 4x this.
        num_blocks: number of residual blocks at the bottleneck.
        name: variable scope (also used by the exporter/trainer to select
            generator variables by name).
        reuse: reuse variables in the scope.

    Returns:
        NHWC tensor with 3 channels (unclipped generator output).
    """
    with tf.variable_scope(name, reuse=reuse):
        # Encoder: full resolution.
        x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
        x0 = tf.nn.leaky_relu(x0)

        # Downsample 1: half resolution, 2x channels.
        x1 = slim.convolution2d(
            x0, channel, [3, 3], stride=2, activation_fn=None)
        x1 = tf.nn.leaky_relu(x1)
        x1 = slim.convolution2d(x1, channel * 2, [3, 3], activation_fn=None)
        x1 = tf.nn.leaky_relu(x1)

        # Downsample 2: quarter resolution, 4x channels.
        x2 = slim.convolution2d(
            x1, channel * 2, [3, 3], stride=2, activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)
        x2 = slim.convolution2d(x2, channel * 4, [3, 3], activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)

        # Bottleneck residual blocks.
        for idx in range(num_blocks):
            x2 = resblock(
                x2, out_channel=channel * 4, name='block_{}'.format(idx))

        x2 = slim.convolution2d(x2, channel * 2, [3, 3], activation_fn=None)
        x2 = tf.nn.leaky_relu(x2)

        # Upsample 1 with skip connection from x1.
        h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
        x3 = tf.image.resize_bilinear(x2, (h1 * 2, w1 * 2))
        x3 = slim.convolution2d(
            x3 + x1, channel * 2, [3, 3], activation_fn=None)
        x3 = tf.nn.leaky_relu(x3)
        x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
        x3 = tf.nn.leaky_relu(x3)

        # Upsample 2 with skip connection from x0, back to full resolution.
        h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
        x4 = tf.image.resize_bilinear(x3, (h2 * 2, w2 * 2))
        x4 = slim.convolution2d(x4 + x0, channel, [3, 3], activation_fn=None)
        x4 = tf.nn.leaky_relu(x4)
        x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
        # x4 = tf.clip_by_value(x4, -1, 1)

        return x4
def disc_sn(x,
            scale=1,
            channel=32,
            patch=True,
            name='discriminator',
            reuse=False):
    """Spectral-norm discriminator with three stride-2 stages.

    Args:
        x: NHWC input tensor.
        scale: unused here; kept for interface compatibility with callers.
        channel: base channel width, doubled at each stage.
        patch: when truthy, emit a 1-channel patch map (PatchGAN); otherwise
            global-average-pool and emit a single scalar logit per sample.
        name: variable scope name.
        reuse: reuse variables in the scope.

    Returns:
        Discriminator logits (patch map or per-sample scalar).
    """
    with tf.variable_scope(name, reuse=reuse):
        for idx in range(3):
            x = conv_spectral_norm(
                x,
                channel * 2**idx, [3, 3],
                stride=2,
                name='conv{}_1'.format(idx))
            x = tf.nn.leaky_relu(x)

            x = conv_spectral_norm(
                x, channel * 2**idx, [3, 3], name='conv{}_2'.format(idx))
            x = tf.nn.leaky_relu(x)

        # Idiom fix: `if patch is True` compared identity with the True
        # singleton (PEP 8 E712); plain truthiness is the intended check.
        if patch:
            x = conv_spectral_norm(x, 1, [1, 1], name='conv_out')
        else:
            x = tf.reduce_mean(x, axis=[1, 2])
            x = slim.fully_connected(x, 1, activation_fn=None)

        return x
if __name__ == '__main__':
pass

View File

@@ -1,9 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import cv2
import numpy as np
import tensorflow as tf
def resize_size(image, size=720):
@@ -83,11 +85,87 @@ def find_pupil(landmarks, np_img):
return xm + xmin, ym + ymin
def next_batch(filename_list, batch_size, fineSize=256):
    """Load `batch_size` random images and take one random `fineSize` crop
    from each, scaled to [-1, 1].

    NOTE: every image must be at least fineSize x fineSize.
    """
    order = np.arange(0, len(filename_list))
    np.random.shuffle(order)
    chosen = order[:batch_size]
    crops = []
    for i in range(batch_size):
        img = cv2.imread(filename_list[chosen[i]])
        height, width, _ = img.shape
        # Random top-left corner of the crop window (width first to keep
        # the RNG call order of the original implementation).
        left = random.randint(0, width - fineSize)
        top = random.randint(0, height - fineSize)
        crop = img[top:top + fineSize, left:left + fineSize, :]
        crops.append(crop.astype(np.float32) / 127.5 - 1)
    return np.asarray(crops)
def read_image(image_path, IMAGE_SIZE=256):
    """Decode an image file into a square BGR float tensor in [-1, 1]."""
    raw = tf.io.read_file(image_path)
    img = tf.image.decode_image(raw, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img.set_shape([None, None, 3])
    img = tf.image.resize(images=img, size=[IMAGE_SIZE, IMAGE_SIZE])
    # RGB -> BGR, matching the cv2-based parts of the pipeline.
    img = img[..., ::-1]
    # image = image / 127.5 - 1
    # convert_image_dtype left values in [0, 1]; rescale to [-1, 1].
    img = (img - 0.5) * 2
    return img
def load_data(photo_list):
    """Dataset map function: decode a single photo path into a tensor."""
    return read_image(photo_list)
def tf_data_loader(image_list, batch_size):
    """Build a shuffled, batched, prefetching tf.data pipeline over paths."""
    ds = tf.data.Dataset.from_tensor_slices((image_list))
    ds = ds.shuffle(len(image_list))
    ds = ds.map(load_data, num_parallel_calls=4)
    return ds.batch(batch_size).prefetch(1)
def write_batch_image(image, save_dir, name, n):
    """Tile the first n*n images into an n-by-n grid and save it to disk.

    Args:
        image: sequence of at least n*n HWC images with values in [-1, 1].
        save_dir: output directory (created if missing).
        name: output file name.
        n: grid side length.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    fused_dir = os.path.join(save_dir, name)
    rows = []
    for i in range(n):
        row = []
        for j in range(n):
            # Bugfix: convert into a local copy instead of writing the
            # rescaled values back into the caller's batch (the old code
            # mutated image[k] in place).
            tile = (np.asarray(image[i * n + j]) + 1) * 127.5
            row.append(np.clip(tile, 0, 255))
        rows.append(np.hstack(row))
    fused_image = np.vstack(rows)
    cv2.imwrite(fused_dir, fused_image.astype(np.uint8))
def grid_batch_image(image, n):
    """Tile the first n*n images into an n-by-n grid, rescaled to [0, 255].

    Args:
        image: sequence of at least n*n HWC images with values in [-1, 1].
        n: grid side length.

    Returns:
        The fused grid as a numpy array. Bugfix: unlike the original, the
        input sequence is no longer mutated in place.
    """
    rows = []
    for i in range(n):
        row = []
        for j in range(n):
            tile = (np.asarray(image[i * n + j]) + 1) * 127.5
            row.append(np.clip(tile, 0, 255))
        rows.append(np.hstack(row))
    return np.vstack(rows)
def all_file(file_dir):
    """Recursively collect image file paths under `file_dir`.

    Bugfix: the source contained two stacked, conflicting extension checks
    (a diff artifact); they are collapsed into a single membership test
    that also accepts upper-case '.JPG'.

    Returns:
        List of full paths to files ending in .png/.jpg/.jpeg/.JPG.
    """
    paths = []
    for root, _, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] in ('.png', '.jpg', '.jpeg',
                                             '.JPG'):
                paths.append(os.path.join(root, file))
    return paths

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .image_inpainting_trainer import ImageInpaintingTrainer
from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer
from .image_defrcn_fewshot_detection_trainer import ImageDefrcnFewshotTrainer
from .cartoon_translation_trainer import CartoonTranslationTrainer
else:
_import_structure = {
@@ -23,7 +24,8 @@ else:
'referring_video_object_segmentation_trainer':
['ReferringVideoObjectSegmentationTrainer'],
'image_defrcn_fewshot_detection_trainer':
['ImageDefrcnFewshotTrainer']
['ImageDefrcnFewshotTrainer'],
'cartoon_translation_trainer': ['CartoonTranslationTrainer']
}
import sys

View File

@@ -0,0 +1,241 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
from typing import Dict, Optional
import numpy as np
import tensorflow as tf
from packaging import version
from tqdm import tqdm
from modelscope.models.cv.cartoon import (CartoonModel, all_file,
simple_superpixel, tf_data_loader,
write_batch_image)
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
logger = get_logger()

# This trainer is written against the TF1 graph/session APIs; warn when a
# TF2 runtime is detected. (Bugfix: the message referenced an undefined
# `_tf_version` name, which raised NameError under TF2.)
_tf_version = tf.__version__
if version.parse(_tf_version) >= version.parse('2'):
    logger.info(
        f'TensorFlow version {_tf_version} found, TF2.x is not supported by CartoonTranslationTrainer.'
    )
@TRAINERS.register_module(module_name=r'cartoon-translation')
class CartoonTranslationTrainer(BaseTrainer):
    """TF1 session-based trainer for the white-box person-image-cartoon model."""

    def __init__(self,
                 model: str,
                 cfg_file: str = None,
                 work_dir=None,
                 photo=None,
                 cartoon=None,
                 max_steps=None,
                 *args,
                 **kwargs):
        """
        Args:
            model: the model_id of trained model
            cfg_file: the path of configuration file
            work_dir: the path to save training results
            photo: the path of photo images for training
            cartoon: the path of cartoon images for training
            max_steps: the number of total iteration for training

        Returns:
            initialized trainer: object of CartoonTranslationTrainer
        """
        model = self.get_or_download_model_dir(model)
        # Start from a clean graph: this trainer builds all ops at init time.
        tf.reset_default_graph()
        self.model_dir = model
        self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
        if cfg_file is None:
            cfg_file = osp.join(model, ModelFile.CONFIGURATION)
        super().__init__(cfg_file)

        self.params = {}
        self._override_params_from_file()
        # Explicit constructor arguments win over the configuration file.
        if work_dir is not None:
            self.params['work_dir'] = work_dir
        if photo is not None:
            self.params['photo'] = photo
        if cartoon is not None:
            self.params['cartoon'] = cartoon
        if max_steps is not None:
            self.params['max_steps'] = max_steps

        if not os.path.exists(self.params['work_dir']):
            os.makedirs(self.params['work_dir'])

        self.face_photo_list = all_file(self.params['photo'])
        self.face_cartoon_list = all_file(self.params['cartoon'])

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        self._session = tf.Session(config=tf_config)

        # Fixed-shape placeholders: (batch, patch, patch, 3).
        self.input_photo = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])
        self.input_superpixel = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])
        self.input_cartoon = tf.placeholder(tf.float32, [
            self.params['batch_size'], self.params['patch_size'],
            self.params['patch_size'], 3
        ])

        # Build the training graph: output plus generator/discriminator losses.
        self.model = CartoonModel(self.model_dir)
        output = self.model(self.input_photo, self.input_cartoon,
                            self.input_superpixel)
        self.output_cartoon = output['output_cartoon']
        self.g_loss = output['g_loss']
        self.d_loss = output['d_loss']

        tf.summary.scalar('g_loss', self.g_loss)
        tf.summary.scalar('d_loss', self.d_loss)
        self.train_writer = tf.summary.FileWriter(self.params['work_dir']
                                                  + '/train_log')
        self.summary_op = tf.summary.merge_all()

        # Split variables by name: generator vs discriminator optimizers.
        all_vars = tf.trainable_variables()
        gene_vars = [var for var in all_vars if 'gene' in var.name]
        disc_vars = [var for var in all_vars if 'disc' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.g_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
                .minimize(self.g_loss, var_list=gene_vars)
            self.d_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
                .minimize(self.d_loss, var_list=disc_vars)

        self.saver = tf.train.Saver(max_to_keep=1000)
        with self._session.as_default() as sess:
            sess.run(tf.global_variables_initializer())
            if self.params['resume_epoch'] != 0:
                logger.info(f'loading model from {self.model_path}')
                self.saver.restore(
                    sess,
                    osp.join(self.model_path,
                             'model-' + str(self.params['resume_epoch'])))

    def _override_params_from_file(self):
        # Copy the training hyper-parameters from the configuration file.
        self.params['photo'] = self.cfg['train']['photo']
        self.params['cartoon'] = self.cfg['train']['cartoon']
        self.params['patch_size'] = self.cfg['train']['patch_size']
        self.params['work_dir'] = self.cfg['train']['work_dir']
        self.params['batch_size'] = self.cfg['train']['batch_size']
        self.params['adv_train_lr'] = self.cfg['train']['adv_train_lr']
        self.params['max_steps'] = self.cfg['train']['max_steps']
        self.params['logging_interval'] = self.cfg['train']['logging_interval']
        self.params['ckpt_period_interval'] = self.cfg['train'][
            'ckpt_period_interval']
        self.params['resume_epoch'] = self.cfg['train']['resume_epoch']
        self.params['num_gpus'] = self.cfg['train']['num_gpus']

    def train(self, *args, **kwargs):
        logger.info('Begin local cartoon translator training')
        photo_ds = tf_data_loader(self.face_photo_list,
                                  self.params['batch_size'])
        cartoon_ds = tf_data_loader(self.face_cartoon_list,
                                    self.params['batch_size'])
        photo_iterator = photo_ds.make_initializable_iterator()
        cartoon_iterator = cartoon_ds.make_initializable_iterator()
        photo_next = photo_iterator.get_next()
        cartoon_next = cartoon_iterator.get_next()

        # Bugfix: `tf.test.is_gpu_available` is a function; referencing it
        # without calling was always truthy, forcing 'gpu:0' on CPU hosts.
        device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu:0'
        with tf.device(device):
            # Loop variable renamed from `max_steps`, which shadowed the
            # constructor parameter of the same name.
            for step in tqdm(range(self.params['max_steps'])):
                self._session.run(photo_iterator.initializer)
                self._session.run(cartoon_iterator.initializer)
                photo_batch, cartoon_batch = self._session.run(
                    [photo_next, cartoon_next])
                # Superpixel targets are computed from the current generator
                # output on the CPU (SLIC), then fed back for the G update.
                transfer_res = self._session.run(
                    self.output_cartoon,
                    feed_dict={self.input_photo: photo_batch})
                input_superpixel = simple_superpixel(transfer_res, seg_num=200)
                # Alternate generator and discriminator updates each step.
                g_loss, _ = self._session.run(
                    [self.g_loss, self.g_optim],
                    feed_dict={
                        self.input_photo: photo_batch,
                        self.input_superpixel: input_superpixel,
                        self.input_cartoon: cartoon_batch
                    })
                d_loss, _, train_info = self._session.run(
                    [self.d_loss, self.d_optim, self.summary_op],
                    feed_dict={
                        self.input_photo: photo_batch,
                        self.input_superpixel: input_superpixel,
                        self.input_cartoon: cartoon_batch
                    })

                self.train_writer.add_summary(train_info, step)

                if np.mod(step + 1, self.params['logging_interval']
                          ) == 0 or step == 0:
                    logger.info(
                        f'Iter: {step}, d_loss: {d_loss}, g_loss: {g_loss}')
                if np.mod(step + 1,
                          self.params['ckpt_period_interval']
                          ) == 0 or step == 0:
                    self.saver.save(
                        self._session,
                        self.params['work_dir'] + '/saved_models/model',
                        write_meta_graph=False,
                        global_step=step)

                    # Dump a sample grid of results for visual inspection.
                    result_face = self._session.run(
                        self.output_cartoon,
                        feed_dict={
                            self.input_photo: photo_batch,
                            self.input_superpixel: photo_batch,
                            self.input_cartoon: cartoon_batch
                        })
                    write_batch_image(
                        result_face, self.params['work_dir'] + '/images',
                        str('%8d' % step) + '_face_result.jpg', 4)
                    write_batch_image(
                        photo_batch, self.params['work_dir'] + '/images',
                        str('%8d' % step) + '_face_photo.jpg', 4)

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """evaluate a dataset

        evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
        does not exist, read from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the results about the evaluation
            Example:
            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        raise NotImplementedError(
            'evaluate is not supported by CartoonTranslationTrainer')

View File

@@ -59,6 +59,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
- test_video_deinterlace.py
- test_image_inpainting_sdv2.py
- test_bad_image_detecting.py
- test_image_portrait_stylization_trainer.py
- test_controllable_image_generation.py
envs:
@@ -80,3 +81,4 @@ envs:
- test_person_image_cartoon.py
- test_skin_retouching.py
- test_image_style_transfer.py
- test_image_portrait_stylization_trainer.py

View File

@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import cv2
from modelscope.exporters.cv import CartoonTranslationExporter
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.trainers.cv import CartoonTranslationTrainer
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestImagePortraitStylizationTrainer(unittest.TestCase):
    """End-to-end smoke test: train a few steps, export a frozen graph and
    run the inference pipeline on the exported model."""

    def setUp(self) -> None:
        self.task = Tasks.image_portrait_stylization
        self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
        # Miniature clipart training set hosted on the ModelScope hub;
        # the split config exposes the local extraction directory.
        data_dir = MsDataset.load(
            'dctnet_train_clipart_mini_ms',
            namespace='menyifang',
            split='train').config_kwargs['split_config']['train']
        data_photo = os.path.join(data_dir, 'face_photo')
        data_cartoon = os.path.join(data_dir, 'face_cartoon')
        work_dir = 'exp_localtoon'
        max_steps = 10
        trainer = CartoonTranslationTrainer(
            model=model_id,
            work_dir=work_dir,
            photo=data_photo,
            cartoon=data_cartoon,
            max_steps=max_steps)
        trainer.train()

        # The trainer checkpoints at step 0, so 'model-0' always exists.
        ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0))
        pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb')
        exporter = CartoonTranslationExporter()
        exporter.export_frozen_graph_def(
            ckpt_path=ckpt_path, frozen_graph_path=pb_path)
        self.pipeline_person_image_cartoon(trainer.model_dir)

    def pipeline_person_image_cartoon(self, model_dir):
        # Smoke-check the exported model through the inference pipeline.
        pipeline_cartoon = pipeline(task=self.task, model=model_dir)
        result = pipeline_cartoon(input=self.test_image)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {os.path.abspath("result.png")}')
if __name__ == '__main__':
unittest.main()