mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #43259593]add cv tryon task
增加虚拟试衣任务,输入模特图,骨骼图,衣服展示图,生成试衣效果图
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9401415
This commit is contained in:
3
data/test/images/virtual_tryon_cloth.jpg
Normal file
3
data/test/images/virtual_tryon_cloth.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8ce0d25b3392f140bf35fba9c6711fdcfc2efde536600aa48dace35462e81adf
|
||||
size 8825
|
||||
3
data/test/images/virtual_tryon_model.jpg
Normal file
3
data/test/images/virtual_tryon_model.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bb76a61306d3d311d440c5c695958909166e04fb34c827d74d766ba830945d6f
|
||||
size 5034
|
||||
3
data/test/images/virtual_tryon_pose.jpg
Normal file
3
data/test/images/virtual_tryon_pose.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0ab9baf18074b6b5655ee546794789395757486d6e2180c2627aad47b819e505
|
||||
size 11778
|
||||
@@ -60,6 +60,7 @@ class Pipelines(object):
|
||||
action_recognition = 'TAdaConv_action-recognition'
|
||||
animal_recognation = 'resnet101-animal_recog'
|
||||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
|
||||
virtual_tryon = 'virtual_tryon'
|
||||
image_colorization = 'unet-image-colorization'
|
||||
image_super_resolution = 'rrdb-image-super-resolution'
|
||||
face_image_generation = 'gan-face-image-generation'
|
||||
|
||||
0
modelscope/models/cv/virual_tryon/__init__.py
Normal file
0
modelscope/models/cv/virual_tryon/__init__.py
Normal file
442
modelscope/models/cv/virual_tryon/sdafnet.py
Normal file
442
modelscope/models/cv/virual_tryon/sdafnet.py
Normal file
@@ -0,0 +1,442 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
|
||||
def apply_offset(offset):
|
||||
sizes = list(offset.size()[2:])
|
||||
grid_list = torch.meshgrid(
|
||||
[torch.arange(size, device=offset.device) for size in sizes])
|
||||
grid_list = reversed(grid_list)
|
||||
# apply offset
|
||||
grid_list = [
|
||||
grid.float().unsqueeze(0) + offset[:, dim, ...]
|
||||
for dim, grid in enumerate(grid_list)
|
||||
]
|
||||
# normalize
|
||||
grid_list = [
|
||||
grid / ((size - 1.0) / 2.0) - 1.0
|
||||
for grid, size in zip(grid_list, reversed(sizes))
|
||||
]
|
||||
|
||||
return torch.stack(grid_list, dim=-1)
|
||||
|
||||
|
||||
# backbone
|
||||
class ResBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super(ResBlock, self).__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3,
|
||||
padding=1, bias=False), nn.BatchNorm2d(in_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, padding=1,
|
||||
bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x) + x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Downsample, self).__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class FeatureEncoder(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]):
|
||||
# in_channels = 3 for images, and is larger (e.g., 17+1+1) for agnositc representation
|
||||
super(FeatureEncoder, self).__init__()
|
||||
self.encoders = []
|
||||
for i, out_chns in enumerate(chns):
|
||||
if i == 0:
|
||||
encoder = nn.Sequential(
|
||||
Downsample(in_channels, out_chns), ResBlock(out_chns),
|
||||
ResBlock(out_chns))
|
||||
else:
|
||||
encoder = nn.Sequential(
|
||||
Downsample(chns[i - 1], out_chns), ResBlock(out_chns),
|
||||
ResBlock(out_chns))
|
||||
|
||||
self.encoders.append(encoder)
|
||||
|
||||
self.encoders = nn.ModuleList(self.encoders)
|
||||
|
||||
def forward(self, x):
|
||||
encoder_features = []
|
||||
for encoder in self.encoders:
|
||||
x = encoder(x)
|
||||
encoder_features.append(x)
|
||||
return encoder_features
|
||||
|
||||
|
||||
class RefinePyramid(nn.Module):
|
||||
|
||||
def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256):
|
||||
super(RefinePyramid, self).__init__()
|
||||
self.chns = chns
|
||||
|
||||
# adaptive
|
||||
self.adaptive = []
|
||||
for in_chns in list(reversed(chns)):
|
||||
adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
|
||||
self.adaptive.append(adaptive_layer)
|
||||
self.adaptive = nn.ModuleList(self.adaptive)
|
||||
# output conv
|
||||
self.smooth = []
|
||||
for i in range(len(chns)):
|
||||
smooth_layer = nn.Conv2d(
|
||||
fpn_dim, fpn_dim, kernel_size=3, padding=1)
|
||||
self.smooth.append(smooth_layer)
|
||||
self.smooth = nn.ModuleList(self.smooth)
|
||||
|
||||
def forward(self, x):
|
||||
conv_ftr_list = x
|
||||
|
||||
feature_list = []
|
||||
last_feature = None
|
||||
for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
|
||||
# adaptive
|
||||
feature = self.adaptive[i](conv_ftr)
|
||||
# fuse
|
||||
if last_feature is not None:
|
||||
feature = feature + F.interpolate(
|
||||
last_feature, scale_factor=2, mode='nearest')
|
||||
# smooth
|
||||
feature = self.smooth[i](feature)
|
||||
last_feature = feature
|
||||
feature_list.append(feature)
|
||||
|
||||
return tuple(reversed(feature_list))
|
||||
|
||||
|
||||
def DAWarp(feat, offsets, att_maps, sample_k, out_ch):
|
||||
att_maps = torch.repeat_interleave(att_maps, out_ch, 1)
|
||||
B, C, H, W = feat.size()
|
||||
multi_feat = torch.repeat_interleave(feat, sample_k, 0)
|
||||
multi_warp_feat = F.grid_sample(
|
||||
multi_feat,
|
||||
offsets.detach().permute(0, 2, 3, 1),
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps
|
||||
att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1))
|
||||
return att_warp_feat
|
||||
|
||||
|
||||
class MFEBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
num_filters=[128, 64, 32]):
|
||||
super(MFEBlock, self).__init__()
|
||||
layers = []
|
||||
for i in range(len(num_filters)):
|
||||
if i == 0:
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=num_filters[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1))
|
||||
else:
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=num_filters[i - 1],
|
||||
out_channels=num_filters[i],
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=kernel_size // 2))
|
||||
layers.append(
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=num_filters[-1],
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=kernel_size // 2))
|
||||
self.layers = torch.nn.Sequential(*layers)
|
||||
|
||||
def forward(self, input):
|
||||
return self.layers(input)
|
||||
|
||||
|
||||
class DAFlowNet(nn.Module):
|
||||
|
||||
def __init__(self, num_pyramid, fpn_dim=256, head_nums=1):
|
||||
super(DAFlowNet, self).__init__()
|
||||
self.Self_MFEs = []
|
||||
|
||||
self.Cross_MFEs = []
|
||||
self.Refine_MFEs = []
|
||||
self.k = head_nums
|
||||
self.out_ch = fpn_dim
|
||||
for i in range(num_pyramid):
|
||||
# self-MFE for model img 2k:flow 1k:att_map
|
||||
Self_MFE_layer = MFEBlock(
|
||||
in_channels=2 * fpn_dim,
|
||||
out_channels=self.k * 3,
|
||||
kernel_size=7)
|
||||
# cross-MFE for cloth img
|
||||
Cross_MFE_layer = MFEBlock(
|
||||
in_channels=2 * fpn_dim, out_channels=self.k * 3)
|
||||
# refine-MFE for cloth and model imgs
|
||||
Refine_MFE_layer = MFEBlock(
|
||||
in_channels=2 * fpn_dim, out_channels=self.k * 6)
|
||||
self.Self_MFEs.append(Self_MFE_layer)
|
||||
self.Cross_MFEs.append(Cross_MFE_layer)
|
||||
self.Refine_MFEs.append(Refine_MFE_layer)
|
||||
|
||||
self.Self_MFEs = nn.ModuleList(self.Self_MFEs)
|
||||
self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs)
|
||||
self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs)
|
||||
|
||||
self.lights_decoder = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(
|
||||
in_channels=32,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1))
|
||||
self.lights_encoder = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(
|
||||
3, out_channels=32, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(
|
||||
in_channels=32, out_channels=64, kernel_size=1, stride=1))
|
||||
|
||||
def forward(self,
|
||||
source_image,
|
||||
reference_image,
|
||||
source_feats,
|
||||
reference_feats,
|
||||
return_all=False,
|
||||
warp_feature=True,
|
||||
use_light_en_de=True):
|
||||
r"""
|
||||
Args:
|
||||
source_image: cloth rgb image for tryon
|
||||
reference_image: model rgb image for try on
|
||||
source_feats: cloth FPN features
|
||||
reference_feats: model and pose features
|
||||
return_all: bool return all intermediate try-on results in training phase
|
||||
warp_feature: use DAFlow for both features and images
|
||||
use_light_en_de: use shallow encoder and decoder to project the images from RGB to high dimensional space
|
||||
|
||||
"""
|
||||
|
||||
# reference branch inputs model img using self-DAFlow
|
||||
last_multi_self_offsets = None
|
||||
# source branch inputs cloth img using cross-DAFlow
|
||||
last_multi_cross_offsets = None
|
||||
|
||||
if return_all:
|
||||
results_all = []
|
||||
|
||||
for i in range(len(source_feats)):
|
||||
|
||||
feat_source = source_feats[len(source_feats) - 1 - i]
|
||||
feat_ref = reference_feats[len(reference_feats) - 1 - i]
|
||||
B, C, H, W = feat_source.size()
|
||||
|
||||
# Pre-DAWarp for Pyramid feature
|
||||
if last_multi_cross_offsets is not None and warp_feature:
|
||||
att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, self.out_ch)
|
||||
att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
|
||||
self_att_maps, self.k, self.out_ch)
|
||||
else:
|
||||
att_source_feat = feat_source
|
||||
att_reference_feat = feat_ref
|
||||
# Cross-MFE
|
||||
input_feat = torch.cat([att_source_feat, feat_ref], 1)
|
||||
offsets_att = self.Cross_MFEs[i](input_feat)
|
||||
cross_att_maps = F.softmax(
|
||||
offsets_att[:, self.k * 2:, :, :], dim=1)
|
||||
offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
|
||||
-1, 2, H, W))
|
||||
if last_multi_cross_offsets is not None:
|
||||
offsets = F.grid_sample(
|
||||
last_multi_cross_offsets,
|
||||
offsets,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
else:
|
||||
offsets = offsets.permute(0, 3, 1, 2)
|
||||
last_multi_cross_offsets = offsets
|
||||
att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, self.out_ch)
|
||||
|
||||
# Self-MFE
|
||||
input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
|
||||
offsets_att = self.Self_MFEs[i](input_feat)
|
||||
self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1)
|
||||
offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
|
||||
-1, 2, H, W))
|
||||
if last_multi_self_offsets is not None:
|
||||
offsets = F.grid_sample(
|
||||
last_multi_self_offsets,
|
||||
offsets,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
else:
|
||||
offsets = offsets.permute(0, 3, 1, 2)
|
||||
last_multi_self_offsets = offsets
|
||||
att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
|
||||
self_att_maps, self.k, self.out_ch)
|
||||
|
||||
# Refine-MFE
|
||||
input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
|
||||
offsets_att = self.Refine_MFEs[i](input_feat)
|
||||
att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1)
|
||||
cross_offsets = apply_offset(
|
||||
offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W))
|
||||
self_offsets = apply_offset(
|
||||
offsets_att[:,
|
||||
self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W))
|
||||
last_multi_cross_offsets = F.grid_sample(
|
||||
last_multi_cross_offsets,
|
||||
cross_offsets,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
last_multi_self_offsets = F.grid_sample(
|
||||
last_multi_self_offsets,
|
||||
self_offsets,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
|
||||
# Upsampling
|
||||
last_multi_cross_offsets = F.interpolate(
|
||||
last_multi_cross_offsets, scale_factor=2, mode='bilinear')
|
||||
last_multi_self_offsets = F.interpolate(
|
||||
last_multi_self_offsets, scale_factor=2, mode='bilinear')
|
||||
self_att_maps = F.interpolate(
|
||||
att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear')
|
||||
cross_att_maps = F.interpolate(
|
||||
att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear')
|
||||
|
||||
# Post-DAWarp for source and reference images
|
||||
if return_all:
|
||||
cur_source_image = F.interpolate(
|
||||
source_image, (H * 2, W * 2), mode='bilinear')
|
||||
cur_reference_image = F.interpolate(
|
||||
reference_image, (H * 2, W * 2), mode='bilinear')
|
||||
if use_light_en_de:
|
||||
cur_source_image = self.lights_encoder(cur_source_image)
|
||||
cur_reference_image = self.lights_encoder(
|
||||
cur_reference_image)
|
||||
# the feat dim in light encoder is 64
|
||||
warp_att_source_image = DAWarp(cur_source_image,
|
||||
last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, 64)
|
||||
warp_att_reference_image = DAWarp(cur_reference_image,
|
||||
last_multi_self_offsets,
|
||||
self_att_maps, self.k,
|
||||
64)
|
||||
result_tryon = self.lights_decoder(
|
||||
warp_att_source_image + warp_att_reference_image)
|
||||
else:
|
||||
warp_att_source_image = DAWarp(cur_source_image,
|
||||
last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, 3)
|
||||
warp_att_reference_image = DAWarp(cur_reference_image,
|
||||
last_multi_self_offsets,
|
||||
self_att_maps, self.k, 3)
|
||||
result_tryon = warp_att_source_image + warp_att_reference_image
|
||||
results_all.append(result_tryon)
|
||||
|
||||
last_multi_self_offsets = F.interpolate(
|
||||
last_multi_self_offsets,
|
||||
reference_image.size()[2:],
|
||||
mode='bilinear')
|
||||
last_multi_cross_offsets = F.interpolate(
|
||||
last_multi_cross_offsets, source_image.size()[2:], mode='bilinear')
|
||||
self_att_maps = F.interpolate(
|
||||
self_att_maps, reference_image.size()[2:], mode='bilinear')
|
||||
cross_att_maps = F.interpolate(
|
||||
cross_att_maps, source_image.size()[2:], mode='bilinear')
|
||||
if use_light_en_de:
|
||||
source_image = self.lights_encoder(source_image)
|
||||
reference_image = self.lights_encoder(reference_image)
|
||||
warp_att_source_image = DAWarp(source_image,
|
||||
last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, 64)
|
||||
warp_att_reference_image = DAWarp(reference_image,
|
||||
last_multi_self_offsets,
|
||||
self_att_maps, self.k, 64)
|
||||
result_tryon = self.lights_decoder(warp_att_source_image
|
||||
+ warp_att_reference_image)
|
||||
else:
|
||||
warp_att_source_image = DAWarp(source_image,
|
||||
last_multi_cross_offsets,
|
||||
cross_att_maps, self.k, 3)
|
||||
warp_att_reference_image = DAWarp(reference_image,
|
||||
last_multi_self_offsets,
|
||||
self_att_maps, self.k, 3)
|
||||
result_tryon = warp_att_source_image + warp_att_reference_image
|
||||
|
||||
if return_all:
|
||||
return result_tryon, return_all
|
||||
return result_tryon
|
||||
|
||||
|
||||
class SDAFNet_Tryon(nn.Module):
|
||||
|
||||
def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6):
|
||||
super(SDAFNet_Tryon, self).__init__()
|
||||
num_filters = [64, 128, 256, 256, 256]
|
||||
self.source_features = FeatureEncoder(source_in_channel, num_filters)
|
||||
self.reference_features = FeatureEncoder(ref_in_channel, num_filters)
|
||||
self.source_FPN = RefinePyramid(num_filters)
|
||||
self.reference_FPN = RefinePyramid(num_filters)
|
||||
self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums)
|
||||
|
||||
def forward(self,
|
||||
ref_input,
|
||||
source_image,
|
||||
ref_image,
|
||||
use_light_en_de=True,
|
||||
return_all=False,
|
||||
warp_feature=True):
|
||||
reference_feats = self.reference_FPN(
|
||||
self.reference_features(ref_input))
|
||||
source_feats = self.source_FPN(self.source_features(source_image))
|
||||
result = self.dafnet(
|
||||
source_image,
|
||||
ref_image,
|
||||
source_feats,
|
||||
reference_feats,
|
||||
use_light_en_de=use_light_en_de,
|
||||
return_all=return_all,
|
||||
warp_feature=warp_feature)
|
||||
return result
|
||||
@@ -285,5 +285,10 @@ TASK_OUTPUTS = {
|
||||
# {
|
||||
# "output_pcm": {"input_label" : np.ndarray with shape [D]}
|
||||
# }
|
||||
Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM]
|
||||
Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],
|
||||
# virtual_tryon result for a single sample
|
||||
# {
|
||||
# "output_img": np.ndarray with shape [height, width, 3]
|
||||
# }
|
||||
Tasks.virtual_tryon: [OutputKeys.OUTPUT_IMG]
|
||||
}
|
||||
|
||||
@@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.text_to_image_synthesis:
|
||||
(Pipelines.text_to_image_synthesis,
|
||||
'damo/cv_imagen_text-to-image-synthesis_tiny'),
|
||||
Tasks.virtual_tryon: (Pipelines.virtual_tryon,
|
||||
'damo/cv_daflow_virtual-tryon_base'),
|
||||
Tasks.image_colorization: (Pipelines.image_colorization,
|
||||
'damo/cv_unet_image-colorization'),
|
||||
Tasks.style_transfer: (Pipelines.style_transfer,
|
||||
|
||||
@@ -6,6 +6,7 @@ try:
|
||||
from .action_recognition_pipeline import ActionRecognitionPipeline
|
||||
from .animal_recog_pipeline import AnimalRecogPipeline
|
||||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
|
||||
from .virtual_tryon_pipeline import VirtualTryonPipeline
|
||||
from .image_colorization_pipeline import ImageColorizationPipeline
|
||||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
|
||||
from .face_image_generation_pipeline import FaceImageGenerationPipeline
|
||||
|
||||
124
modelscope/pipelines/cv/virtual_tryon_pipeline.py
Normal file
124
modelscope/pipelines/cv/virtual_tryon_pipeline.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon
|
||||
from modelscope.outputs import TASK_OUTPUTS, OutputKeys
|
||||
from modelscope.pipelines.util import is_model, is_official_hub_path
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from ..base import Pipeline
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.virtual_tryon, module_name=Pipelines.virtual_tryon)
|
||||
class VirtualTryonPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
self.device = torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def filter_param(src_params, own_state):
|
||||
copied_keys = []
|
||||
for name, param in src_params.items():
|
||||
if 'module.' == name[0:7]:
|
||||
name = name[7:]
|
||||
if '.module.' not in list(own_state.keys())[0]:
|
||||
name = name.replace('.module.', '.')
|
||||
if (name in own_state) and (own_state[name].shape
|
||||
== param.shape):
|
||||
own_state[name].copy_(param)
|
||||
copied_keys.append(name)
|
||||
|
||||
def load_pretrained(model, src_params):
|
||||
if 'state_dict' in src_params:
|
||||
src_params = src_params['state_dict']
|
||||
own_state = model.state_dict()
|
||||
filter_param(src_params, own_state)
|
||||
model.load_state_dict(own_state)
|
||||
|
||||
self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device)
|
||||
local_model_dir = model
|
||||
if osp.exists(model):
|
||||
local_model_dir = model
|
||||
else:
|
||||
local_model_dir = snapshot_download(model)
|
||||
self.local_path = local_model_dir
|
||||
src_params = torch.load(
|
||||
osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu')
|
||||
load_pretrained(self.model, src_params)
|
||||
self.model = self.model.eval()
|
||||
self.size = 192
|
||||
self.test_transforms = transforms.Compose([
|
||||
transforms.Resize(self.size, interpolation=2),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if isinstance(input['masked_model'], str):
|
||||
img_agnostic = load_image(input['masked_model'])
|
||||
pose = load_image(input['pose'])
|
||||
cloth_img = load_image(input['cloth'])
|
||||
elif isinstance(input['masked_model'], PIL.Image.Image):
|
||||
img_agnostic = img_agnostic.convert('RGB')
|
||||
pose = pose.convert('RGB')
|
||||
cloth_img = cloth_img.convert('RGB')
|
||||
elif isinstance(input['masked_model'], np.ndarray):
|
||||
if len(input.shape) == 2:
|
||||
img_agnostic = cv2.cvtColor(img_agnostic, cv2.COLOR_GRAY2BGR)
|
||||
pose = cv2.cvtColor(pose, cv2.COLOR_GRAY2BGR)
|
||||
cloth_img = cv2.cvtColor(cloth_img, cv2.COLOR_GRAY2BGR)
|
||||
img_agnostic = Image.fromarray(
|
||||
img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB')
|
||||
pose = Image.fromarray(
|
||||
pose[:, :, ::-1].astype('uint8')).convert('RGB')
|
||||
cloth_img = Image.fromarray(
|
||||
cloth_img[:, :, ::-1].astype('uint8')).convert('RGB')
|
||||
else:
|
||||
raise TypeError(f'input should be either str, PIL.Image,'
|
||||
f' np.array, but got {type(input)}')
|
||||
|
||||
img_agnostic = self.test_transforms(img_agnostic)
|
||||
pose = self.test_transforms(pose)
|
||||
cloth_img = self.test_transforms(cloth_img)
|
||||
inputs = {
|
||||
'masked_model': img_agnostic.unsqueeze(0),
|
||||
'pose': pose.unsqueeze(0),
|
||||
'cloth': cloth_img.unsqueeze(0)
|
||||
}
|
||||
return inputs
|
||||
|
||||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
img_agnostic = inputs['masked_model'].to(self.device)
|
||||
pose = inputs['pose'].to(self.device)
|
||||
cloth_img = inputs['cloth'].to(self.device)
|
||||
ref_input = torch.cat((pose, img_agnostic), dim=1)
|
||||
tryon_result = self.model(ref_input, cloth_img, img_agnostic)
|
||||
return {OutputKeys.OUTPUT_IMG: tryon_result}
|
||||
|
||||
def postprocess(self, outputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tryon_result = outputs[OutputKeys.OUTPUT_IMG].permute(0, 2, 3,
|
||||
1).squeeze(0)
|
||||
tryon_result = tryon_result.add(1.).div(2.).mul(255).data.cpu().numpy()
|
||||
outputs[OutputKeys.OUTPUT_IMG] = tryon_result
|
||||
return outputs
|
||||
@@ -27,6 +27,7 @@ class CVTasks(object):
|
||||
ocr_detection = 'ocr-detection'
|
||||
action_recognition = 'action-recognition'
|
||||
video_embedding = 'video-embedding'
|
||||
virtual_tryon = 'virtual-tryon'
|
||||
image_colorization = 'image-colorization'
|
||||
face_image_generation = 'face-image-generation'
|
||||
image_super_resolution = 'image-super-resolution'
|
||||
|
||||
36
tests/pipelines/test_virtual_tryon.py
Normal file
36
tests/pipelines/test_virtual_tryon.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class VirtualTryonTest(unittest.TestCase):
|
||||
model_id = 'damo/cv_daflow_virtual-tryon_base'
|
||||
input_imgs = {
|
||||
'masked_model': 'data/test/images/virtual_tryon_model.jpg',
|
||||
'pose': 'data/test/images/virtual_tryon_pose.jpg',
|
||||
'cloth': 'data/test/images/virtual_tryon_cloth.jpg'
|
||||
}
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_name(self):
|
||||
pipeline_virtual_tryon = pipeline(
|
||||
task=Tasks.virtual_tryon, model=self.model_id)
|
||||
img = pipeline_virtual_tryon(self.input_imgs)[OutputKeys.OUTPUT_IMG]
|
||||
cv2.imwrite('demo.jpg', img[:, :, ::-1])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_model_name_default_model(self):
|
||||
pipeline_virtual_tryon = pipeline(task=Tasks.virtual_tryon)
|
||||
img = pipeline_virtual_tryon(self.input_imgs)[OutputKeys.OUTPUT_IMG]
|
||||
cv2.imwrite('demo.jpg', img[:, :, ::-1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user