From e63593f3bbce32b86d453bb5075ba956de742427 Mon Sep 17 00:00:00 2001
From: myf272609
Date: Tue, 28 Feb 2023 17:01:34 +0800
Subject: [PATCH] [to #42322933] add finetune support for cartoon task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add training support for the portrait cartoonization model.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11675597

* add finetune support for cartoon
---
 modelscope/exporters/__init__.py              |   1 +
 modelscope/exporters/cv/__init__.py           |   6 +
 .../cv/cartoon_translation_exporter.py        |  72 +++++
 modelscope/models/cv/cartoon/__init__.py      |  14 +-
 modelscope/models/cv/cartoon/loss.py          | 274 ++++++++++++++++++
 modelscope/models/cv/cartoon/model_tf.py      |  64 ++++
 modelscope/models/cv/cartoon/network.py       | 146 ++++++++++
 modelscope/models/cv/cartoon/utils.py         |  80 ++++-
 modelscope/trainers/cv/__init__.py            |   4 +-
 .../cv/cartoon_translation_trainer.py         | 241 +++++++++++++++
 tests/run_config.yaml                         |   2 +
 ...test_image_portrait_stylization_trainer.py |  61 ++++
 12 files changed, 961 insertions(+), 4 deletions(-)
 create mode 100644 modelscope/exporters/cv/__init__.py
 create mode 100644 modelscope/exporters/cv/cartoon_translation_exporter.py
 create mode 100644 modelscope/models/cv/cartoon/loss.py
 create mode 100644 modelscope/models/cv/cartoon/model_tf.py
 create mode 100644 modelscope/models/cv/cartoon/network.py
 create mode 100644 modelscope/trainers/cv/cartoon_translation_trainer.py
 create mode 100644 tests/trainers/test_image_portrait_stylization_trainer.py

diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py
index 0c773dca..62f89f2a 100644
--- a/modelscope/exporters/__init__.py
+++ b/modelscope/exporters/__init__.py
@@ -5,6 +5,7 @@ from .base import Exporter
 from .builder import build_exporter
 
 if is_tf_available():
+    from .cv import CartoonTranslationExporter
     from .nlp import CsanmtForTranslationExporter
     from .tf_model_exporter import TfModelExporter
 if is_torch_available():
diff --git a/modelscope/exporters/cv/__init__.py b/modelscope/exporters/cv/__init__.py
new file mode 100644
index 00000000..956c3cb2
--- /dev/null
+++ b/modelscope/exporters/cv/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.utils.import_utils import is_tf_available
+
+if is_tf_available():
+    from .cartoon_translation_exporter import CartoonTranslationExporter
diff --git a/modelscope/exporters/cv/cartoon_translation_exporter.py b/modelscope/exporters/cv/cartoon_translation_exporter.py
new file mode 100644
index 00000000..79b859cb
--- /dev/null
+++ b/modelscope/exporters/cv/cartoon_translation_exporter.py
@@ -0,0 +1,72 @@
+import os
+from typing import Any, Dict
+
+import tensorflow as tf
+from packaging import version
+
+from modelscope.exporters.builder import EXPORTERS
+from modelscope.exporters.tf_model_exporter import TfModelExporter
+from modelscope.models.cv.cartoon import CartoonModel
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+if version.parse(tf.__version__) < version.parse('2'):
+    pass
+else:
+    logger.warning(
+        f'TensorFlow version {tf.__version__} found, TF2.x is not supported by CartoonTranslationExporter.'
+    )
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+@EXPORTERS.register_module(module_name=r'cartoon-translation')
+class CartoonTranslationExporter(TfModelExporter):
+
+    def __init__(self, model=None):
+        super().__init__(model)
+
+    def export_frozen_graph_def(self, ckpt_path: str, frozen_graph_path: str,
+                                **kwargs):
+        """Rebuild the inference graph, restore the generator weights from
+        `ckpt_path` and freeze them into a GraphDef at `frozen_graph_path`."""
+        tf.get_variable_scope().reuse_variables()
+
+        input = tf.placeholder(tf.float32, [None, None, 3], name='input_image')
+        input = input[tf.newaxis, :, :, :]
+        input = input / 127.5 - 1.0
+
+        model = CartoonModel(model_dir='')
+        output = model(input)
+        final_out = output['output_cartoon'][0]
+        final_out = tf.clip_by_value(final_out, -0.999999, 0.999999)
+        final_out = (final_out + 1.0) * 127.5
+        final_out = tf.cast(final_out, tf.uint8, name='output_image')
+
+        all_vars = tf.trainable_variables()
+        gene_vars = [var for var in all_vars if 'generator' in var.name]
+        saver = tf.train.Saver(var_list=gene_vars)
+
+        init = tf.global_variables_initializer()
+        config = tf.ConfigProto(allow_soft_placement=True)
+        config.gpu_options.allow_growth = True
+        with tf.Session(config=config) as sess:
+
+            sess.run(init)
+            saver.restore(sess, ckpt_path)
+            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
+                sess, sess.graph_def, output_node_names=['output_image'])
+            with open(frozen_graph_path, 'wb') as f:
+                f.write(frozen_graph_def.SerializeToString())
+            logger.info(f'Frozen graph written to {frozen_graph_path}')
+
+        return {'model': frozen_graph_path}
+
+    def export_saved_model(self, output_dir: str, **kwargs):
+        raise NotImplementedError(
+            'Exporting saved model is not supported by CartoonTranslationExporter currently.'
+        )
+
+    def export_onnx(self, output_dir: str, **kwargs):
+        raise NotImplementedError(
+            'Exporting onnx model is not supported by CartoonTranslationExporter currently.'
+ ) diff --git a/modelscope/models/cv/cartoon/__init__.py b/modelscope/models/cv/cartoon/__init__.py index 131f5cac..447335c6 100644 --- a/modelscope/models/cv/cartoon/__init__.py +++ b/modelscope/models/cv/cartoon/__init__.py @@ -7,14 +7,24 @@ if TYPE_CHECKING: from .facelib.facer import FaceAna from .mtcnn_pytorch.src.align_trans import (get_reference_facial_points, warp_and_crop_face) - from .utils import (get_f5p, padTo16x, resize_size) + from .utils import (get_f5p, padTo16x, resize_size, all_file, + tf_data_loader, write_batch_image) + from .network import disc_sn + from .loss import simple_superpixel + from .model_tf import CartoonModel else: _import_structure = { 'facelib.facer': ['FaceAna'], 'mtcnn_pytorch.src.align_trans': ['get_reference_facial_points', 'warp_and_crop_face'], - 'utils': ['get_f5p', 'padTo16x', 'resize_size'] + 'utils': [ + 'get_f5p', 'padTo16x', 'resize_size', 'all_file', 'tf_data_loader', + 'write_batch_image' + ], + 'network': ['disc_sn'], + 'loss': ['simple_superpixel'], + 'model_tf': ['CartoonModel'], } import sys diff --git a/modelscope/models/cv/cartoon/loss.py b/modelscope/models/cv/cartoon/loss.py new file mode 100644 index 00000000..9010d29a --- /dev/null +++ b/modelscope/models/cv/cartoon/loss.py @@ -0,0 +1,274 @@ +''' +CVPR 2020 submission, Paper ID 6791 +Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' +''' + +import os.path as osp + +import numpy as np +import scipy.stats as st +import tensorflow as tf +from joblib import Parallel, delayed +from skimage import color, segmentation + +from .network import disc_sn + +VGG_MEAN = [103.939, 116.779, 123.68] + + +class Vgg19: + + def __init__(self, vgg19_npy_path=None): + + self.data_dict = np.load( + vgg19_npy_path, encoding='latin1', allow_pickle=True).item() + + print('Finished loading vgg19.npy') + + def build_conv4_4(self, rgb, include_fc=False): + + rgb_scaled = (rgb + 1) * 127.5 + + blue, green, red = tf.split( + axis=3, num_or_size_splits=3, value=rgb_scaled) + bgr = tf.concat( + axis=3, + values=[ + blue - VGG_MEAN[0], green - VGG_MEAN[1], red - VGG_MEAN[2] + ]) + + self.conv1_1 = self.conv_layer(bgr, 'conv1_1') + self.relu1_1 = tf.nn.relu(self.conv1_1) + self.conv1_2 = self.conv_layer(self.relu1_1, 'conv1_2') + self.relu1_2 = tf.nn.relu(self.conv1_2) + self.pool1 = self.max_pool(self.relu1_2, 'pool1') + + self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1') + self.relu2_1 = tf.nn.relu(self.conv2_1) + self.conv2_2 = self.conv_layer(self.relu2_1, 'conv2_2') + self.relu2_2 = tf.nn.relu(self.conv2_2) + self.pool2 = self.max_pool(self.relu2_2, 'pool2') + + self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1') + self.relu3_1 = tf.nn.relu(self.conv3_1) + self.conv3_2 = self.conv_layer(self.relu3_1, 'conv3_2') + self.relu3_2 = tf.nn.relu(self.conv3_2) + self.conv3_3 = self.conv_layer(self.relu3_2, 'conv3_3') + self.relu3_3 = tf.nn.relu(self.conv3_3) + self.conv3_4 = self.conv_layer(self.relu3_3, 'conv3_4') + self.relu3_4 = tf.nn.relu(self.conv3_4) + self.pool3 = self.max_pool(self.relu3_4, 'pool3') + + self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1') + self.relu4_1 = tf.nn.relu(self.conv4_1) + self.conv4_2 = self.conv_layer(self.relu4_1, 'conv4_2') + self.relu4_2 = tf.nn.relu(self.conv4_2) + self.conv4_3 = self.conv_layer(self.relu4_2, 'conv4_3') + self.relu4_3 = tf.nn.relu(self.conv4_3) + self.conv4_4 = self.conv_layer(self.relu4_3, 'conv4_4') + self.relu4_4 = tf.nn.relu(self.conv4_4) + self.pool4 = self.max_pool(self.relu4_4, 'pool4') + + return self.conv4_4 + + 
def max_pool(self, bottom, name): + return tf.nn.max_pool( + bottom, + ksize=[1, 2, 2, 1], + strides=[1, 2, 2, 1], + padding='SAME', + name=name) + + def conv_layer(self, bottom, name): + with tf.variable_scope(name): + filt = self.get_conv_filter(name) + conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') + conv_biases = self.get_bias(name) + bias = tf.nn.bias_add(conv, conv_biases) + return bias + + def fc_layer(self, bottom, name): + with tf.variable_scope(name): + shape = bottom.get_shape().as_list() + dim = 1 + for d in shape[1:]: + dim *= d + x = tf.reshape(bottom, [-1, dim]) + weights = self.get_fc_weight(name) + biases = self.get_bias(name) + fc = tf.nn.bias_add(tf.matmul(x, weights), biases) + + return fc + + def get_conv_filter(self, name): + return tf.constant(self.data_dict[name][0], name='filter') + + def get_bias(self, name): + return tf.constant(self.data_dict[name][1], name='biases') + + def get_fc_weight(self, name): + return tf.constant(self.data_dict[name][0], name='weights') + + +def content_loss(model_dir, input_photo, transfer_res, input_superpixel): + vgg_model = Vgg19(osp.join(model_dir, 'vgg19.npy')) + vgg_photo = vgg_model.build_conv4_4(input_photo) + vgg_output = vgg_model.build_conv4_4(transfer_res) + vgg_superpixel = vgg_model.build_conv4_4(input_superpixel) + h, w, c = vgg_photo.get_shape().as_list()[1:] + abs_photo = tf.losses.absolute_difference(vgg_photo, vgg_output) + photo_loss = tf.reduce_mean(abs_photo) / (h * w * c) + abs_superpixel = tf.losses.absolute_difference(vgg_superpixel, vgg_output) + superpixel_loss = tf.reduce_mean(abs_superpixel) / (h * w * c) + loss = photo_loss + superpixel_loss + + return loss + + +def style_loss(input_cartoon, output_cartoon): + blur_fake = guided_filter(output_cartoon, output_cartoon, r=5, eps=2e-1) + blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1) + + gray_fake, gray_cartoon = color_shift(output_cartoon, input_cartoon) + + d_loss_gray, g_loss_gray = lsgan_loss( + disc_sn, + gray_cartoon, + gray_fake, + scale=1, + patch=True, + name='disc_gray') + d_loss_blur, g_loss_blur = lsgan_loss( + disc_sn, + blur_cartoon, + blur_fake, + scale=1, + patch=True, + name='disc_blur') + sty_g_loss = (g_loss_blur) + g_loss_gray + sty_d_loss = d_loss_blur + d_loss_gray + + return sty_g_loss, sty_d_loss + + +def gan_loss(discriminator, + real, + fake, + scale=1, + channel=32, + patch=False, + name='discriminator'): + + real_logit = discriminator( + real, scale, channel, name=name, patch=patch, reuse=False) + fake_logit = discriminator( + fake, scale, channel, name=name, patch=patch, reuse=True) + + real_logit = tf.nn.sigmoid(real_logit) + fake_logit = tf.nn.sigmoid(fake_logit) + + g_loss_blur = -tf.reduce_mean(tf.log(fake_logit)) + d_loss_blur = -tf.reduce_mean(tf.log(real_logit) + tf.log(1. 
- fake_logit))
+
+    return d_loss_blur, g_loss_blur
+
+
+def lsgan_loss(discriminator,
+               real,
+               fake,
+               scale=1,
+               channel=32,
+               patch=False,
+               name='discriminator'):
+
+    real_logit = discriminator(
+        real, scale, channel, name=name, patch=patch, reuse=False)
+    fake_logit = discriminator(
+        fake, scale, channel, name=name, patch=patch, reuse=True)
+
+    g_loss = tf.reduce_mean((fake_logit - 1)**2)
+    d_loss = 0.5 * (
+        tf.reduce_mean((real_logit - 1)**2) + tf.reduce_mean(fake_logit**2))
+
+    return d_loss, g_loss
+
+
+def total_variation_loss(image, k_size=1):
+    h, w = image.get_shape().as_list()[1:3]
+    tv_h = tf.reduce_mean(
+        (image[:, k_size:, :, :] - image[:, :h - k_size, :, :])**2)
+    tv_w = tf.reduce_mean(
+        (image[:, :, k_size:, :] - image[:, :, :w - k_size, :])**2)
+    tv_loss = (tv_h + tv_w) / (3 * h * w)
+    return tv_loss
+
+
+def guided_filter(x, y, r, eps=1e-2):
+    x_shape = tf.shape(x)
+    N = tf_box_filter(
+        tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
+
+    mean_x = tf_box_filter(x, r) / N
+    mean_y = tf_box_filter(y, r) / N
+    cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
+    var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
+
+    A = cov_xy / (var_x + eps)
+    b = mean_y - A * mean_x
+
+    mean_A = tf_box_filter(A, r) / N
+    mean_b = tf_box_filter(b, r) / N
+
+    output = mean_A * x + mean_b
+
+    return output
+
+
+def color_shift(image1, image2, mode='uniform'):
+    b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
+    b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
+    if mode == 'normal':
+        b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
+        g_weight = tf.random.normal(shape=[1], mean=0.587, stddev=0.1)
+        r_weight = tf.random.normal(shape=[1], mean=0.299, stddev=0.1)
+    elif mode == 'uniform':
+        b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
+        g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
+        r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
+    output1 = (b_weight * b1 + g_weight * g1 + r_weight * r1) / (
+        b_weight + g_weight + r_weight)
+    output2 = (b_weight * b2 + g_weight * g2 + r_weight * r2) / (
+        b_weight + g_weight + r_weight)
+    return output1, output2
+
+
+def simple_superpixel(batch_image, seg_num=200):
+
+    def process_slic(image):
+        seg_label = segmentation.slic(
+            image,
+            n_segments=seg_num,
+            sigma=1,
+            compactness=10,
+            convert2lab=True,
+            start_label=1)
+        image = color.label2rgb(seg_label, image, kind='avg', bg_label=0)
+        return image
+
+    num_job = np.shape(batch_image)[0]
+    batch_out = Parallel(n_jobs=num_job)(
+        delayed(process_slic)(image) for image in batch_image)
+    return np.array(batch_out)
+
+
+def tf_box_filter(x, r):
+    ch = x.get_shape().as_list()[-1]
+    weight = 1 / ((2 * r + 1)**2)
+    box_kernel = weight * np.ones((2 * r + 1, 2 * r + 1, ch, 1))
+    box_kernel = np.array(box_kernel).astype(np.float32)
+    output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
+    return output
+
+
+if __name__ == '__main__':
+    pass
diff --git a/modelscope/models/cv/cartoon/model_tf.py b/modelscope/models/cv/cartoon/model_tf.py
new file mode 100644
index 00000000..4d7cf4ed
--- /dev/null
+++ b/modelscope/models/cv/cartoon/model_tf.py
@@ -0,0 +1,64 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
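+
+# TF1.x graph-mode model for white-box cartoonization: a U-Net generator
+# whose output is smoothed with a guided filter. In training mode the call
+# also wires up the VGG content loss, the LSGAN style losses and a
+# total-variation loss from loss.py. Rough usage sketch:
+#
+#     model = CartoonModel(model_dir)
+#     out = model(input_photo, input_cartoon, input_superpixel)
+#     g_loss, d_loss = out['g_loss'], out['d_loss']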
+
+from typing import Dict
+
+import tensorflow as tf
+
+from modelscope.models.base import Model, Tensor
+from .loss import content_loss, guided_filter, style_loss, total_variation_loss
+from .network import unet_generator
+
+
+class CartoonModel(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+
+    def __call__(
+            self,
+            input_photo: Tensor,
+            input_cartoon: Tensor = None,
+            input_superpixel: Tensor = None) -> Dict[str, Tensor]:
+        """return the result by the model
+
+        Args:
+            input_photo: the preprocessed input photo image
+            input_cartoon: the preprocessed input cartoon image
+            input_superpixel: the computed input superpixel image
+
+        Returns:
+            output_dict: a dict holding the cartoonized image under
+                'output_cartoon' and, when the cartoon and superpixel inputs
+                are given, the generator and discriminator losses under
+                'g_loss' and 'd_loss'
+        """
+        if input_cartoon is None:
+            output = unet_generator(input_photo)
+            output_cartoon = guided_filter(input_photo, output, r=1)
+            return {'output_cartoon': output_cartoon}
+        else:
+            output = unet_generator(input_photo)
+            output_cartoon = guided_filter(input_photo, output, r=1)
+
+            con_loss = content_loss(self.model_dir, input_photo,
+                                    output_cartoon, input_superpixel)
+            sty_g_loss, sty_d_loss = style_loss(input_cartoon, output_cartoon)
+            tv_loss = total_variation_loss(output_cartoon)
+
+            g_loss = 1e-1 * sty_g_loss + 2e2 * con_loss + 1e4 * tv_loss
+            d_loss = sty_d_loss
+
+            return {
+                'output_cartoon': output_cartoon,
+                'g_loss': g_loss,
+                'd_loss': d_loss,
+            }
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Run the forward pass for a model.
+
+        Args:
+            input (Dict[str, Tensor]): the dict of the model inputs for the forward method
+
+        Returns:
+            Dict[str, Tensor]: output from the model forward pass
+        """
diff --git a/modelscope/models/cv/cartoon/network.py b/modelscope/models/cv/cartoon/network.py
new file mode 100644
index 00000000..6a00a506
--- /dev/null
+++ b/modelscope/models/cv/cartoon/network.py
@@ -0,0 +1,146 @@
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+
+
+def resblock(inputs, out_channel=32, name='resblock'):
+    with tf.variable_scope(name):
+        x = slim.convolution2d(
+            inputs, out_channel, [3, 3], activation_fn=None, scope='conv1')
+
+        x = tf.nn.leaky_relu(x)
+        x = slim.convolution2d(
+            x, out_channel, [3, 3], activation_fn=None, scope='conv2')
+
+        return x + inputs
+
+
+def spectral_norm(w, iteration=1):
+    w_shape = w.shape.as_list()
+    w = tf.reshape(w, [-1, w_shape[-1]])
+
+    u = tf.get_variable(
+        'u', [1, w_shape[-1]],
+        initializer=tf.random_normal_initializer(),
+        trainable=False)
+
+    u_hat = u
+    v_hat = None
+    # Power iteration; a single iteration is usually sufficient.
+    for i in range(iteration):
+        v_ = tf.matmul(u_hat, tf.transpose(w))
+        v_hat = tf.nn.l2_normalize(v_)
+
+        u_ = tf.matmul(v_hat, w)
+        u_hat = tf.nn.l2_normalize(u_)
+
+    u_hat = tf.stop_gradient(u_hat)
+    v_hat = tf.stop_gradient(v_hat)
+
+    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
+
+    with tf.control_dependencies([u.assign(u_hat)]):
+        w_norm = w / sigma
+        w_norm = tf.reshape(w_norm, w_shape)
+
+    return w_norm
+
+
+def conv_spectral_norm(x, channel, k_size, stride=1, name='conv_snorm'):
+    with tf.variable_scope(name):
+        w = tf.get_variable(
+            'kernel', shape=[k_size[0], k_size[1],
+                             x.get_shape()[-1], channel])
+        b = tf.get_variable(
+            'bias', [channel], initializer=tf.constant_initializer(0.0))
+
+        x = tf.nn.conv2d(
+            input=x,
+            filter=spectral_norm(w),
+            strides=[1, stride, stride, 1],
+            padding='SAME') + b
+
+        return x
+
+
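+# U-Net-style generator (layout sketch of the function below): two stride-2
+# downsampling convolutions, `num_blocks` residual blocks at the lowest
+# resolution, then two bilinear-upsampling stages that merge the skip
+# features x1 and x0 back in before the final 7x7 projection to RGB.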
+def unet_generator(inputs,
+                   channel=32,
+                   num_blocks=4,
+                   name='generator',
+                   reuse=False):
+    with tf.variable_scope(name, reuse=reuse):
+
+        x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
+        x0 = tf.nn.leaky_relu(x0)
+
+        x1 = slim.convolution2d(
+            x0, channel, [3, 3], stride=2, activation_fn=None)
+        x1 = tf.nn.leaky_relu(x1)
+        x1 = slim.convolution2d(x1, channel * 2, [3, 3], activation_fn=None)
+        x1 = tf.nn.leaky_relu(x1)
+
+        x2 = slim.convolution2d(
+            x1, channel * 2, [3, 3], stride=2, activation_fn=None)
+        x2 = tf.nn.leaky_relu(x2)
+        x2 = slim.convolution2d(x2, channel * 4, [3, 3], activation_fn=None)
+        x2 = tf.nn.leaky_relu(x2)
+
+        for idx in range(num_blocks):
+            x2 = resblock(
+                x2, out_channel=channel * 4, name='block_{}'.format(idx))
+
+        x2 = slim.convolution2d(x2, channel * 2, [3, 3], activation_fn=None)
+        x2 = tf.nn.leaky_relu(x2)
+
+        h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
+        x3 = tf.image.resize_bilinear(x2, (h1 * 2, w1 * 2))
+        x3 = slim.convolution2d(
+            x3 + x1, channel * 2, [3, 3], activation_fn=None)
+        x3 = tf.nn.leaky_relu(x3)
+        x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
+        x3 = tf.nn.leaky_relu(x3)
+
+        h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
+        x4 = tf.image.resize_bilinear(x3, (h2 * 2, w2 * 2))
+        x4 = slim.convolution2d(x4 + x0, channel, [3, 3], activation_fn=None)
+        x4 = tf.nn.leaky_relu(x4)
+        x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
+
+        # x4 = tf.clip_by_value(x4, -1, 1)
+        return x4
+
+
+def disc_sn(x,
+            scale=1,
+            channel=32,
+            patch=True,
+            name='discriminator',
+            reuse=False):
+    with tf.variable_scope(name, reuse=reuse):
+
+        for idx in range(3):
+            x = conv_spectral_norm(
+                x,
+                channel * 2**idx, [3, 3],
+                stride=2,
+                name='conv{}_1'.format(idx))
+            x = tf.nn.leaky_relu(x)
+
+            x = conv_spectral_norm(
+                x, channel * 2**idx, [3, 3], name='conv{}_2'.format(idx))
+            x = tf.nn.leaky_relu(x)
+
+        if patch:
+            x = conv_spectral_norm(x, 1, [1, 1], name='conv_out')
+
+        else:
+            x = tf.reduce_mean(x, axis=[1, 2])
+            x = slim.fully_connected(x, 1, activation_fn=None)
+
+        return x
+
+
+if __name__ == '__main__':
+    pass
diff --git a/modelscope/models/cv/cartoon/utils.py b/modelscope/models/cv/cartoon/utils.py
index 59b4e879..f93e3324 100644
--- a/modelscope/models/cv/cartoon/utils.py
+++ b/modelscope/models/cv/cartoon/utils.py
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
import os +import random import cv2 import numpy as np +import tensorflow as tf def resize_size(image, size=720): @@ -83,11 +85,87 @@ def find_pupil(landmarks, np_img): return xm + xmin, ym + ymin +def next_batch(filename_list, batch_size, fineSize=256): + idx = np.arange(0, len(filename_list)) + np.random.shuffle(idx) + idx = idx[:batch_size] + batch_data = [] + for i in range(batch_size): + image = cv2.imread(filename_list[idx[i]]) + h, w, c = image.shape + rw = random.randint(0, w - fineSize) + rh = random.randint(0, h - fineSize) + image = image[rh:rh + fineSize, rw:rw + fineSize, :] + image = image.astype(np.float32) / 127.5 - 1 + batch_data.append(image) + + return np.asarray(batch_data) + + +def read_image(image_path, IMAGE_SIZE=256): + image = tf.io.read_file(image_path) + image = tf.image.decode_image(image, channels=3) + image = tf.image.convert_image_dtype(image, tf.float32) + image.set_shape([None, None, 3]) + image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE]) + image = image[..., ::-1] + # image = image / 127.5 - 1 + image = (image - 0.5) * 2 + + return image + + +def load_data(photo_list): + photo = read_image(photo_list) + return photo + + +def tf_data_loader(image_list, batch_size): + dataset = tf.data.Dataset.from_tensor_slices((image_list)) + dataset = dataset.shuffle(len(image_list)) + dataset = dataset.map(load_data, num_parallel_calls=4) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(1) + return dataset + + +def write_batch_image(image, save_dir, name, n): + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + fused_dir = os.path.join(save_dir, name) + fused_image = [0] * n + for i in range(n): + fused_image[i] = [] + for j in range(n): + k = i * n + j + image[k] = (image[k] + 1) * 127.5 + image[k] = np.clip(image[k], 0, 255) + fused_image[i].append(image[k]) + fused_image[i] = np.hstack(fused_image[i]) + fused_image = np.vstack(fused_image) + cv2.imwrite(fused_dir, fused_image.astype(np.uint8)) + + +def grid_batch_image(image, n): + fused_image = [0] * n + for i in range(n): + fused_image[i] = [] + for j in range(n): + k = i * n + j + image[k] = (image[k] + 1) * 127.5 + image[k] = np.clip(image[k], 0, 255) + fused_image[i].append(image[k]) + fused_image[i] = np.hstack(fused_image[i]) + fused_image = np.vstack(fused_image) + return fused_image + + def all_file(file_dir): L = [] for root, dirs, files in os.walk(file_dir): for file in files: extend = os.path.splitext(file)[1] - if extend == '.png' or extend == '.jpg' or extend == '.jpeg': + if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG': L.append(os.path.join(root, file)) return L diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index 2f682b81..78023c2c 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from .image_inpainting_trainer import ImageInpaintingTrainer from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer from .image_defrcn_fewshot_detection_trainer import ImageDefrcnFewshotTrainer + from .cartoon_translation_trainer import CartoonTranslationTrainer else: _import_structure = { @@ -23,7 +24,8 @@ else: 'referring_video_object_segmentation_trainer': ['ReferringVideoObjectSegmentationTrainer'], 'image_defrcn_fewshot_detection_trainer': - ['ImageDefrcnFewshotTrainer'] + ['ImageDefrcnFewshotTrainer'], + 'cartoon_translation_trainer': ['CartoonTranslationTrainer'] } import sys diff --git 
a/modelscope/trainers/cv/cartoon_translation_trainer.py b/modelscope/trainers/cv/cartoon_translation_trainer.py
new file mode 100644
index 00000000..0e7021da
--- /dev/null
+++ b/modelscope/trainers/cv/cartoon_translation_trainer.py
@@ -0,0 +1,241 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import os.path as osp
+from typing import Dict, Optional
+
+import numpy as np
+import tensorflow as tf
+from packaging import version
+from tqdm import tqdm
+
+from modelscope.models.cv.cartoon import (CartoonModel, all_file,
+                                          simple_superpixel, tf_data_loader,
+                                          write_batch_image)
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+if version.parse(tf.__version__) < version.parse('2'):
+    pass
+else:
+    logger.warning(
+        f'TensorFlow version {tf.__version__} found, TF2.x is not supported by CartoonTranslationTrainer.'
+    )
+
+
+@TRAINERS.register_module(module_name=r'cartoon-translation')
+class CartoonTranslationTrainer(BaseTrainer):
+
+    def __init__(self,
+                 model: str,
+                 cfg_file: str = None,
+                 work_dir=None,
+                 photo=None,
+                 cartoon=None,
+                 max_steps=None,
+                 *args,
+                 **kwargs):
+        """
+        Args:
+            model: the model id of the model to be finetuned
+            cfg_file: the path of the configuration file
+            work_dir: the path used to save training results
+            photo: the path of the photo images used for training
+            cartoon: the path of the cartoon images used for training
+            max_steps: the total number of training iterations
+        Returns:
+            initialized trainer: an instance of CartoonTranslationTrainer
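+
+        Example (sketch; the dataset paths are placeholders):
+            >>> trainer = CartoonTranslationTrainer(
+            >>>     model='damo/cv_unet_person-image-cartoon_compound-models',
+            >>>     work_dir='./exp_localtoon',
+            >>>     photo='/path/to/face_photo',
+            >>>     cartoon='/path/to/face_cartoon',
+            >>>     max_steps=10)
+            >>> trainer.train()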
+        """
+        model = self.get_or_download_model_dir(model)
+        tf.reset_default_graph()
+
+        self.model_dir = model
+        self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER)
+        if cfg_file is None:
+            cfg_file = osp.join(model, ModelFile.CONFIGURATION)
+
+        super().__init__(cfg_file)
+
+        self.params = {}
+        self._override_params_from_file()
+        if work_dir is not None:
+            self.params['work_dir'] = work_dir
+        if photo is not None:
+            self.params['photo'] = photo
+        if cartoon is not None:
+            self.params['cartoon'] = cartoon
+        if max_steps is not None:
+            self.params['max_steps'] = max_steps
+
+        if not os.path.exists(self.params['work_dir']):
+            os.makedirs(self.params['work_dir'])
+
+        self.face_photo_list = all_file(self.params['photo'])
+        self.face_cartoon_list = all_file(self.params['cartoon'])
+
+        tf_config = tf.ConfigProto(allow_soft_placement=True)
+        tf_config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=tf_config)
+
+        self.input_photo = tf.placeholder(tf.float32, [
+            self.params['batch_size'], self.params['patch_size'],
+            self.params['patch_size'], 3
+        ])
+        self.input_superpixel = tf.placeholder(tf.float32, [
+            self.params['batch_size'], self.params['patch_size'],
+            self.params['patch_size'], 3
+        ])
+        self.input_cartoon = tf.placeholder(tf.float32, [
+            self.params['batch_size'], self.params['patch_size'],
+            self.params['patch_size'], 3
+        ])
+
+        self.model = CartoonModel(self.model_dir)
+        output = self.model(self.input_photo, self.input_cartoon,
+                            self.input_superpixel)
+        self.output_cartoon = output['output_cartoon']
+        self.g_loss = output['g_loss']
+        self.d_loss = output['d_loss']
+
+        tf.summary.scalar('g_loss', self.g_loss)
+        tf.summary.scalar('d_loss', self.d_loss)
+
+        self.train_writer = tf.summary.FileWriter(self.params['work_dir']
+                                                  + '/train_log')
+        self.summary_op = tf.summary.merge_all()
+
+        all_vars = tf.trainable_variables()
+        gene_vars = [var for var in all_vars if 'gene' in var.name]
+        disc_vars = [var for var in all_vars if 'disc' in var.name]
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(update_ops):
+            self.g_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
+                .minimize(self.g_loss, var_list=gene_vars)
+            self.d_optim = tf.train.AdamOptimizer(self.params['adv_train_lr'], beta1=0.5, beta2=0.99) \
+                .minimize(self.d_loss, var_list=disc_vars)
+
+        self.saver = tf.train.Saver(max_to_keep=1000)
+        with self._session.as_default() as sess:
+            sess.run(tf.global_variables_initializer())
+            if self.params['resume_epoch'] != 0:
+                logger.info(f'loading model from {self.model_path}')
+                self.saver.restore(
+                    sess,
+                    osp.join(self.model_path,
+                             'model-' + str(self.params['resume_epoch'])))
+
+    def _override_params_from_file(self):
+
+        self.params['photo'] = self.cfg['train']['photo']
+        self.params['cartoon'] = self.cfg['train']['cartoon']
+        self.params['patch_size'] = self.cfg['train']['patch_size']
+        self.params['work_dir'] = self.cfg['train']['work_dir']
+        self.params['batch_size'] = self.cfg['train']['batch_size']
+        self.params['adv_train_lr'] = self.cfg['train']['adv_train_lr']
+        self.params['max_steps'] = self.cfg['train']['max_steps']
+        self.params['logging_interval'] = self.cfg['train']['logging_interval']
+        self.params['ckpt_period_interval'] = self.cfg['train'][
+            'ckpt_period_interval']
+        self.params['resume_epoch'] = self.cfg['train']['resume_epoch']
+        self.params['num_gpus'] = self.cfg['train']['num_gpus']
+
+    def train(self, *args, **kwargs):
+        logger.info('Begin local cartoon translator training')
+
+        photo_ds = tf_data_loader(self.face_photo_list,
+                                  self.params['batch_size'])
+        cartoon_ds = tf_data_loader(self.face_cartoon_list,
+                                    self.params['batch_size'])
+        photo_iterator = photo_ds.make_initializable_iterator()
+        cartoon_iterator = cartoon_ds.make_initializable_iterator()
+        photo_next = photo_iterator.get_next()
+        cartoon_next = cartoon_iterator.get_next()
+
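+        # Each iteration: (1) run the generator on a photo batch, (2) turn
+        # its output into flat-color superpixel targets via SLIC, then
+        # (3) update the generator and the discriminators with the same
+        # feed. The losses are wired up in CartoonModel (see model_tf.py).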
+        device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu:0'
+        with tf.device(device):
+
+            for step in tqdm(range(self.params['max_steps'])):
+
+                self._session.run(photo_iterator.initializer)
+                self._session.run(cartoon_iterator.initializer)
+
+                photo_batch, cartoon_batch = self._session.run(
+                    [photo_next, cartoon_next])
+
+                transfer_res = self._session.run(
+                    self.output_cartoon,
+                    feed_dict={self.input_photo: photo_batch})
+
+                input_superpixel = simple_superpixel(transfer_res, seg_num=200)
+                g_loss, _ = self._session.run(
+                    [self.g_loss, self.g_optim],
+                    feed_dict={
+                        self.input_photo: photo_batch,
+                        self.input_superpixel: input_superpixel,
+                        self.input_cartoon: cartoon_batch
+                    })
+
+                d_loss, _, train_info = self._session.run(
+                    [self.d_loss, self.d_optim, self.summary_op],
+                    feed_dict={
+                        self.input_photo: photo_batch,
+                        self.input_superpixel: input_superpixel,
+                        self.input_cartoon: cartoon_batch
+                    })
+
+                self.train_writer.add_summary(train_info, step)
+
+                if np.mod(step + 1, self.params['logging_interval']
+                          ) == 0 or step == 0:
+
+                    logger.info(
+                        f'Iter: {step}, d_loss: {d_loss}, g_loss: {g_loss}')
+
+                if np.mod(step + 1,
+                          self.params['ckpt_period_interval']
+                          ) == 0 or step == 0:
+                    self.saver.save(
+                        self._session,
+                        self.params['work_dir'] + '/saved_models/model',
+                        write_meta_graph=False,
+                        global_step=step)
+
+                    result_face = self._session.run(
+                        self.output_cartoon,
+                        feed_dict={
+                            self.input_photo: photo_batch,
+                            self.input_superpixel: photo_batch,
+                            self.input_cartoon: cartoon_batch
+                        })
+
+                    write_batch_image(
+                        result_face, self.params['work_dir'] + '/images',
+                        '%08d' % step + '_face_result.jpg', 4)
+                    write_batch_image(
+                        photo_batch, self.params['work_dir'] + '/images',
+                        '%08d' % step + '_face_photo.jpg', 4)
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        """evaluate a dataset
+
+        Evaluate a dataset via the model loaded from `checkpoint_path`; if
+        `checkpoint_path` does not exist, the path is read from the config file.
+
+        Args:
+            checkpoint_path (Optional[str], optional): the model path. Defaults to None.
+
+        Returns:
+            Dict[str, float]: the results of the evaluation
+            Example:
+                {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
+        """
+        raise NotImplementedError(
+            'evaluate is not supported by CartoonTranslationTrainer')
diff --git a/tests/run_config.yaml b/tests/run_config.yaml
index 52efb91d..e7466af6 100644
--- a/tests/run_config.yaml
+++ b/tests/run_config.yaml
@@ -59,6 +59,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
 - test_video_deinterlace.py
 - test_image_inpainting_sdv2.py
 - test_bad_image_detecting.py
+- test_image_portrait_stylization_trainer.py
 - test_controllable_image_generation.py
 
 envs:
@@ -80,3 +81,4 @@ envs:
 - test_person_image_cartoon.py
 - test_skin_retouching.py
 - test_image_style_transfer.py
+- test_image_portrait_stylization_trainer.py
diff --git a/tests/trainers/test_image_portrait_stylization_trainer.py b/tests/trainers/test_image_portrait_stylization_trainer.py
new file mode 100644
index 00000000..6a3c41fa
--- /dev/null
+++ b/tests/trainers/test_image_portrait_stylization_trainer.py
@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
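+
+# Smoke test: finetune for a few steps on a small cartoon dataset, export
+# the resulting checkpoint to a frozen graph, then run the stylization
+# pipeline against the exported model directory.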
+import os +import unittest + +import cv2 + +from modelscope.exporters.cv import CartoonTranslationExporter +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.trainers.cv import CartoonTranslationTrainer +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TestImagePortraitStylizationTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.image_portrait_stylization + self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + model_id = 'damo/cv_unet_person-image-cartoon_compound-models' + + data_dir = MsDataset.load( + 'dctnet_train_clipart_mini_ms', + namespace='menyifang', + split='train').config_kwargs['split_config']['train'] + + data_photo = os.path.join(data_dir, 'face_photo') + data_cartoon = os.path.join(data_dir, 'face_cartoon') + work_dir = 'exp_localtoon' + max_steps = 10 + trainer = CartoonTranslationTrainer( + model=model_id, + work_dir=work_dir, + photo=data_photo, + cartoon=data_cartoon, + max_steps=max_steps) + trainer.train() + + ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0)) + pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb') + exporter = CartoonTranslationExporter() + exporter.export_frozen_graph_def( + ckpt_path=ckpt_path, frozen_graph_path=pb_path) + + self.pipeline_person_image_cartoon(trainer.model_dir) + + def pipeline_person_image_cartoon(self, model_dir): + pipeline_cartoon = pipeline(task=self.task, model=model_dir) + result = pipeline_cartoon(input=self.test_image) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {os.path.abspath("result.png")}') + + +if __name__ == '__main__': + unittest.main()