fix flake8

Author: ly119399
Date: 2024-01-03 00:05:04 +08:00
parent 19cb79ae6c
commit ec07a9919a
25 changed files with 3082 additions and 1857 deletions

View File

@@ -1,2 +1,2 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from . import ldm
from . import ldm

View File

@@ -1,6 +1,7 @@
import pickle
import numpy as np
import cv2
import numpy as np
from skimage.io import imread
@@ -9,16 +10,18 @@ def save_pickle(data, pkl_path):
with open(pkl_path, 'wb') as f:
pickle.dump(data, f)
def read_pickle(pkl_path):
with open(pkl_path, 'rb') as f:
return pickle.load(f)
def draw_epipolar_line(F, img0, img1, pt0, color):
h1,w1=img1.shape[:2]
h1, w1 = img1.shape[:2]
hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None]
l = F @ hpt
l = l[:, 0]
a, b, c = l[0], l[1], l[2]
_l = F @ hpt
_l = _l[:, 0]
a, b, c = _l[0], _l[1], _l[2]
pt1 = np.asarray([0, -c / b]).astype(np.int32)
pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32)
@@ -26,8 +29,9 @@ def draw_epipolar_line(F, img0, img1, pt0, color):
img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2)
return img0, img1
def draw_epipolar_lines(F, img0, img1,num=20):
img0,img1=img0.copy(),img1.copy()
def draw_epipolar_lines(F, img0, img1, num=20):
img0, img1 = img0.copy(), img1.copy()
h0, w0, _ = img0.shape
h1, w1, _ = img1.shape
@@ -42,117 +46,166 @@ def draw_epipolar_lines(F, img0, img1,num=20):
return img0, img1
def compute_F(K1, K2, Rt0, Rt1=None):
if Rt1 is None:
R, t = Rt0[:,:3], Rt0[:,3:]
R, t = Rt0[:, :3], Rt0[:, 3:]
else:
Rt = compute_dR_dt(Rt0,Rt1)
R, t = Rt[:,:3], Rt[:,3:]
A = K1 @ R.T @ t # [3,1]
C = np.asarray([[0,-A[2,0],A[1,0]],
[A[2,0],0,-A[0,0]],
[-A[1,0],A[0,0],0]])
Rt = compute_dR_dt(Rt0, Rt1)
R, t = Rt[:, :3], Rt[:, 3:]
A = K1 @ R.T @ t # [3,1]
C = np.asarray([[0, -A[2, 0], A[1, 0]], [A[2, 0], 0, -A[0, 0]],
[-A[1, 0], A[0, 0], 0]])
F = (np.linalg.inv(K2)).T @ R @ K1.T @ C
return F
def compute_dR_dt(Rt0, Rt1):
R0, t0 = Rt0[:,:3], Rt0[:,3:]
R1, t1 = Rt1[:,:3], Rt1[:,3:]
R0, t0 = Rt0[:, :3], Rt0[:, 3:]
R1, t1 = Rt1[:, :3], Rt1[:, 3:]
dR = np.dot(R1, R0.T)
dt = t1 - np.dot(dR, t0)
return np.concatenate([dR, dt], -1)
def concat_images(img0,img1,vert=False):
def concat_images(img0, img1, vert=False):
if not vert:
h0,h1=img0.shape[0],img1.shape[0],
if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
h0, h1 = img0.shape[0], img1.shape[0],
if h0 < h1:
img0 = cv2.copyMakeBorder(
img0,
0,
h1 - h0,
0,
0,
borderType=cv2.BORDER_CONSTANT,
value=0)
if h1 < h0:
img1 = cv2.copyMakeBorder(
img1,
0,
h0 - h1,
0,
0,
borderType=cv2.BORDER_CONSTANT,
value=0)
img = np.concatenate([img0, img1], axis=1)
else:
w0,w1=img0.shape[1],img1.shape[1]
if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
w0, w1 = img0.shape[1], img1.shape[1]
if w0 < w1:
img0 = cv2.copyMakeBorder(
img0,
0,
0,
0,
w1 - w0,
borderType=cv2.BORDER_CONSTANT,
value=0)
if w1 < w0:
img1 = cv2.copyMakeBorder(
img1,
0,
0,
0,
w0 - w1,
borderType=cv2.BORDER_CONSTANT,
value=0)
img = np.concatenate([img0, img1], axis=0)
return img
def concat_images_list(*args,vert=False):
if len(args)==1: return args[0]
img_out=args[0]
def concat_images_list(*args, vert=False):
if len(args) == 1:
return args[0]
img_out = args[0]
for img in args[1:]:
img_out=concat_images(img_out,img,vert)
img_out = concat_images(img_out, img, vert)
return img_out
def pose_inverse(pose):
R = pose[:,:3].T
t = - R @ pose[:,3:]
return np.concatenate([R,t],-1)
R = pose[:, :3].T
t = -R @ pose[:, 3:]
return np.concatenate([R, t], -1)
def project_points(pts,RT,K):
pts = np.matmul(pts,RT[:,:3].transpose())+RT[:,3:].transpose()
pts = np.matmul(pts,K.transpose())
dpt = pts[:,2]
mask0 = (np.abs(dpt)<1e-4) & (np.abs(dpt)>0)
if np.sum(mask0)>0: dpt[mask0]=1e-4
mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
if np.sum(mask1)>0: dpt[mask1]=-1e-4
pts2d = pts[:,:2]/dpt[:,None]
def project_points(pts, RT, K):
pts = np.matmul(pts, RT[:, :3].transpose()) + RT[:, 3:].transpose()
pts = np.matmul(pts, K.transpose())
dpt = pts[:, 2]
mask0 = (np.abs(dpt) < 1e-4) & (np.abs(dpt) > 0)
if np.sum(mask0) > 0:
dpt[mask0] = 1e-4
mask1 = (np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
if np.sum(mask1) > 0:
dpt[mask1] = -1e-4
pts2d = pts[:, :2] / dpt[:, None]
return pts2d, dpt
def draw_keypoints(img, kps, colors=None, radius=2):
out_img=img.copy()
out_img = img.copy()
for pi, pt in enumerate(kps):
pt = np.round(pt).astype(np.int32)
if colors is not None:
color=[int(c) for c in colors[pi]]
color = [int(c) for c in colors[pi]]
cv2.circle(out_img, tuple(pt), radius, color, -1)
else:
cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1)
cv2.circle(out_img, tuple(pt), radius, (0, 255, 0), -1)
return out_img
def output_points(fn,pts,colors=None):
def output_points(fn, pts, colors=None):
with open(fn, 'w') as f:
for pi, pt in enumerate(pts):
f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
if colors is not None:
f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
f.write(
f'{int(colors[pi, 0])} {int(colors[pi, 1])} {int(colors[pi, 2])}'
)
f.write('\n')
DEPTH_MAX, DEPTH_MIN = 2.4, 0.6
DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63
def read_depth_objaverse(depth_fn):
depth = imread(depth_fn)
depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN
depth = depth.astype(
np.float32) / 65535 * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN
mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX)
return depth, mask
def mask_depth_to_pts(mask,depth,K,rgb=None):
hs,ws=np.nonzero(mask)
depth=depth[hs,ws]
pts=np.asarray([ws,hs,depth],np.float32).transpose()
pts[:,:2]*=pts[:,2:]
def mask_depth_to_pts(mask, depth, K, rgb=None):
hs, ws = np.nonzero(mask)
depth = depth[hs, ws]
pts = np.asarray([ws, hs, depth], np.float32).transpose()
pts[:, :2] *= pts[:, 2:]
if rgb is not None:
return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws]
return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs, ws]
else:
return np.dot(pts, np.linalg.inv(K).transpose())
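Review note: a minimal sketch of mask_depth_to_pts above, using a hypothetical constant depth map and intrinsic matrix (not part of this commit); it back-projects the masked pixels into camera-space 3-D points.
import numpy as np
K = np.array([[300., 0., 64.], [0., 300., 64.], [0., 0., 1.]], np.float32)
depth = np.full((128, 128), 1.2, np.float32)   # fake depth map, 1.2 everywhere
mask = depth > 0                               # keep every pixel
pts = mask_depth_to_pts(mask, depth, K)        # (128*128, 3) camera-space points, z == 1.2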
def transform_points_pose(pts, pose):
R, t = pose[:, :3], pose[:, 3]
if len(pts.shape)==1:
return (R @ pts[:,None] + t[:,None])[:,0]
return pts @ R.T + t[None,:]
if len(pts.shape) == 1:
return (R @ pts[:, None] + t[:, None])[:, 0]
return pts @ R.T + t[None, :]
def pose_apply(pose,pts):
def pose_apply(pose, pts):
return transform_points_pose(pts, pose)
def downsample_gaussian_blur(img, ratio):
sigma = (1 / ratio) / 3
# ksize=np.ceil(2*sigma)
ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1))
ksize = ksize + 1 if ksize % 2 == 0 else ksize
img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
return img
img = cv2.GaussianBlur(
img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
return img
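Review note: a minimal round-trip sketch for the pose helpers above (project_points, pose_inverse, pose_apply), with a hypothetical intrinsic matrix and world-to-camera pose that are not part of this commit.
import numpy as np
K = np.array([[500., 0., 128.], [0., 500., 128.], [0., 0., 1.]], np.float32)
R = np.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]], np.float32)   # 90 deg about z
t = np.array([[0.1], [0.2], [1.5]], np.float32)
pose = np.concatenate([R, t], 1)                 # world-to-camera [R|t], shape (3, 4)
pts = np.random.rand(100, 3).astype(np.float32)  # points in front of the camera
pts2d, depth = project_points(pts, pose, K)      # (100, 2) pixel coords, (100,) depths
back = pose_apply(pose_inverse(pose), pose_apply(pose, pts))
assert np.allclose(back, pts, atol=1e-4)         # pose_inverse undoes the pose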

View File

@@ -1,34 +1,36 @@
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import Encoder, Decoder
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import (
Decoder, Encoder)
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import \
DiagonalGaussianDistribution
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
@@ -36,24 +38,31 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -66,28 +75,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@@ -115,7 +126,7 @@ class VQModel(pl.LightningModule):
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@@ -125,7 +136,8 @@ class VQModel(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
@@ -133,9 +145,10 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach()
return x
@@ -147,79 +160,122 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
self._validation_step(batch, batch_idx, suffix='_ema')
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'scheduler':
LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'scheduler':
LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
@@ -235,7 +291,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
log['inputs'] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
@@ -243,25 +299,28 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
log['inputs'] = x
log['reconstructions'] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
@@ -283,43 +342,48 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
print(f'Restored from {path}')
def encode(self, x):
h = self.encoder(x)
@@ -345,7 +409,8 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
@@ -354,44 +419,91 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val')
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val')
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
@@ -409,21 +521,23 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
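Review note: a minimal sketch of the posterior object that AutoencoderKL.encode() builds from the 2*embed_dim moments produced by quant_conv; tensor sizes are hypothetical and the snippet assumes the standard LDM DiagonalGaussianDistribution API (mean and logvar stacked along the channel axis).
import torch
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import \
    DiagonalGaussianDistribution
moments = torch.randn(2, 8, 32, 32)    # 2 * embed_dim channels: mean (4) + logvar (4)
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                 # (2, 4, 32, 32) latent that decode() consumes
kl = posterior.kl()                    # per-sample KL term used by the training loss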

View File

@@ -1,17 +1,26 @@
import torch
import torch.nn as nn
from modelscope.models.cv.image_to_3d.ldm.modules.attention import default, zero_module, checkpoint
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import UNetModel
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding
import modelscope.models.cv.image_to_3d.ldm.modules.attention as attention
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import \
UNetModel
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
timestep_embedding
class DepthAttention(nn.Module):
def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
def __init__(self,
query_dim,
context_dim,
heads,
dim_head,
output_bias=True):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
context_dim = attention.default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
@@ -34,21 +43,27 @@ class DepthAttention(nn.Module):
b, _, h, w = x.shape
b, _, d, h, w = context.shape
q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w
k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
q = self.to_q(x).reshape(b, hn, hd, h, w) # b,t,h,w
k = self.to_k(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w
v = self.to_v(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
attn = sim.softmax(dim=2)
# b,hn,hd,d,h,w * b,hn,1,d,h,w
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
out = out.reshape(b,hn*hd,h,w)
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
out = out.reshape(b, hn * hd, h, w)
return self.to_out(out)
class DepthTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
def __init__(self,
dim,
n_heads,
d_head,
context_dim=None,
checkpoint=True):
super().__init__()
inner_dim = n_heads * d_head
self.proj_in = nn.Sequential(
@@ -57,23 +72,33 @@ class DepthTransformer(nn.Module):
nn.SiLU(True),
)
self.proj_context = nn.Sequential(
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
nn.GroupNorm(8, context_dim),
nn.ReLU(True), # only relu, because we want input is 0, output is 0
nn.ReLU(
True), # only relu, because we want input is 0, output is 0
)
self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # is a self-attention if not self.disable_self_attn
self.depth_attn = DepthAttention(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim,
output_bias=False
) # is a self-attention if not self.disable_self_attn
self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
attention.zero_module(
nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
)
self.checkpoint = checkpoint
self.checkpoint = attention.checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return attention.checkpoint(self._forward, (x, context),
self.parameters(), self.checkpoint) # noqa
def _forward(self, x, context):
x_in = x
@@ -85,38 +110,65 @@ class DepthTransformer(nn.Module):
class DepthWiseAttention(UNetModel):
def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs):
def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs):
super().__init__(*args, **kwargs)
# num_heads = 4
model_channels = kwargs['model_channels']
channel_mult = kwargs['channel_mult']
d0,d1,d2,d3 = volume_dims
d0, d1, d2, d3 = volume_dims
# 4
ch = model_channels*channel_mult[2]
self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)
ch = model_channels * channel_mult[2]
self.middle_conditions = DepthTransformer(
ch, 4, d3 // 2, context_dim=d3)
self.output_conditions=nn.ModuleList()
self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8}
self.output_conditions = nn.ModuleList()
self.output_b2c = {
3: 0,
4: 1,
5: 2,
6: 3,
7: 4,
8: 5,
9: 6,
10: 7,
11: 8
}
# 8
ch = model_channels*channel_mult[2]
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
ch = model_channels * channel_mult[2]
self.output_conditions.append(
DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
self.output_conditions.append(
DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
# 16
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
ch = model_channels*channel_mult[1]
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
self.output_conditions.append(
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
ch = model_channels * channel_mult[1]
self.output_conditions.append(
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
self.output_conditions.append(
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
# 32
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
ch = model_channels*channel_mult[0]
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
self.output_conditions.append(
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
ch = model_channels * channel_mult[0]
self.output_conditions.append(
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
self.output_conditions.append(
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
self.output_conditions.append(
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
def forward(self,
x,
timesteps=None,
context=None,
source_dict=None,
**kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
@@ -138,5 +190,6 @@ class DepthWiseAttention(UNetModel):
return self.out(h)
def get_trainable_parameters(self):
paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()]
paras = [para for para in self.middle_conditions.parameters()
] + [para for para in self.output_conditions.parameters()]
return paras
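Review note: a shape sketch for DepthTransformer above with hypothetical sizes (the class lives in the file shown in this hunk, whose path is not displayed); x is a 2-D feature map and context is a per-depth volume on the same H x W grid.
import torch
block = DepthTransformer(dim=64, n_heads=4, d_head=16, context_dim=32)
x = torch.randn(1, 64, 8, 8)          # b, dim, h, w
ctx = torch.randn(1, 32, 12, 8, 8)    # b, context_dim, depth, h, w
out = block(x, ctx)                   # same spatial shape as x; attention runs over depth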

View File

@@ -1,10 +1,12 @@
import torch
import torch.nn as nn
class Image2DResBlockWithTV(nn.Module):
def __init__(self, dim, tdim, vdim):
super().__init__()
norm = lambda c: nn.GroupNorm(8, c)
norm = lambda c: nn.GroupNorm(8, c) # noqa
self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
self.conv = nn.Sequential(
@@ -17,22 +19,28 @@ class Image2DResBlockWithTV(nn.Module):
)
def forward(self, x, t, v):
return x+self.conv(x+self.time_embed(t)+self.view_embed(v))
return x + self.conv(x + self.time_embed(t) + self.view_embed(v))
class NoisyTargetViewEncoder(nn.Module):
def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
def __init__(self,
time_embed_dim,
viewpoint_dim,
run_dim=16,
output_dim=8):
super().__init__()
self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim,
viewpoint_dim)
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim,
viewpoint_dim)
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim,
viewpoint_dim)
self.final_out = nn.Sequential(
nn.GroupNorm(8, run_dim),
nn.SiLU(True),
nn.Conv2d(run_dim, output_dim, 3, 1, 1)
)
nn.GroupNorm(8, run_dim), nn.SiLU(True),
nn.Conv2d(run_dim, output_dim, 3, 1, 1))
def forward(self, x, t, v):
B, DT = t.shape
@@ -47,23 +55,33 @@ class NoisyTargetViewEncoder(nn.Module):
x = self.final_out(x)
return x
class SpatialUpTimeBlock(nn.Module):
def __init__(self, x_in_dim, t_in_dim, out_dim):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
self.norm = norm_act(x_in_dim)
self.silu = nn.SiLU(True)
self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
self.conv = nn.ConvTranspose3d(
x_in_dim,
out_dim,
kernel_size=3,
padding=1,
output_padding=1,
stride=2)
def forward(self, x, t):
x = x + self.t_conv(t)
return self.conv(self.silu(self.norm(x)))
class SpatialTimeBlock(nn.Module):
def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
self.bn = norm_act(x_in_dim)
self.silu = nn.SiLU(True)
@@ -73,61 +91,65 @@ class SpatialTimeBlock(nn.Module):
x = x + self.t_conv(t)
return self.conv(self.silu(self.bn(x)))
class SpatialTime3DNet(nn.Module):
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
super().__init__()
d0, d1, d2, d3 = dims
dt = time_dim
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
super().__init__()
d0, d1, d2, d3 = dims
dt = time_dim
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
def forward(self, x, t):
B, C = t.shape
t = t.view(B, C, 1, 1, 1)
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
x = self.init_conv(x)
conv0 = self.conv0(x, t)
def forward(self, x, t):
B, C = t.shape
t = t.view(B, C, 1, 1, 1)
x = self.conv1(conv0, t)
x = self.conv2_0(x, t)
conv2 = self.conv2_1(x, t)
x = self.init_conv(x)
conv0 = self.conv0(x, t)
x = self.conv3(conv2, t)
x = self.conv4_0(x, t)
conv4 = self.conv4_1(x, t)
x = self.conv1(conv0, t)
x = self.conv2_0(x, t)
conv2 = self.conv2_1(x, t)
x = self.conv5(conv4, t)
x = self.conv6_0(x, t)
x = self.conv6_1(x, t)
x = self.conv3(conv2, t)
x = self.conv4_0(x, t)
conv4 = self.conv4_1(x, t)
x = self.conv5(conv4, t)
x = self.conv6_0(x, t)
x = self.conv6_1(x, t)
x = conv4 + self.conv7(x, t)
x = conv2 + self.conv8(x, t)
x = conv0 + self.conv9(x, t)
return x
x = conv4 + self.conv7(x, t)
x = conv2 + self.conv8(x, t)
x = conv0 + self.conv9(x, t)
return x
class FrustumTVBlock(nn.Module):
def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
self.bn = norm_act(x_dim)
self.silu = nn.SiLU(True)
self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)
@@ -136,24 +158,34 @@ class FrustumTVBlock(nn.Module):
x = x + self.t_conv(t) + self.v_conv(v)
return self.conv(self.silu(self.bn(x)))
class FrustumTVUpBlock(nn.Module):
def __init__(self, x_dim, t_dim, v_dim, out_dim):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
self.norm = norm_act(x_dim)
self.silu = nn.SiLU(True)
self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
self.conv = nn.ConvTranspose3d(
x_dim,
out_dim,
kernel_size=3,
padding=1,
output_padding=1,
stride=2)
def forward(self, x, t, v):
x = x + self.t_conv(t) + self.v_conv(v)
return self.conv(self.silu(self.norm(x)))
class FrustumTV3DNet(nn.Module):
def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
super().__init__()
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)
@@ -169,10 +201,10 @@ class FrustumTV3DNet(nn.Module):
self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])
def forward(self, x, t, v):
B,DT = t.shape
t = t.view(B,DT,1,1,1)
B,DV = v.shape
v = v.view(B,DV,1,1,1)
B, DT = t.shape
t = t.view(B, DT, 1, 1, 1)
B, DV = v.shape
v = v.view(B, DV, 1, 1, 1)
b, _, d, h, w = x.shape
x0 = self.conv0(x)
@@ -183,4 +215,4 @@ class FrustumTV3DNet(nn.Module):
x2 = self.up0(x3, t, v) + x2
x1 = self.up1(x2, t, v) + x1
x0 = self.up2(x1, t, v) + x0
return {w: x0, w//2: x1, w//4: x2, w//8: x3}
return {w: x0, w // 2: x1, w // 4: x2, w // 8: x3}
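Review note: a shape sketch for FrustumTV3DNet above with hypothetical sizes (the class is defined in the file shown in this hunk); the forward returns a feature pyramid keyed by spatial width, as the last hunk shows.
import torch
net = FrustumTV3DNet(in_dim=8, t_dim=16, v_dim=16)
x = torch.randn(1, 8, 32, 32, 32)     # b, in_dim, depth, h, w
t = torch.randn(1, 16)                # time embedding
v = torch.randn(1, 16)                # viewpoint embedding
feats = net(x, t, v)                  # dict with keys 32, 16, 8, 4 (width of each level)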

View File

@@ -10,13 +10,13 @@ def project_and_normalize(ref_grid, src_proj, length):
@param length: int
@return: b, n, 2
"""
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
div_val = src_grid[:, -1:]
div_val[div_val<1e-4] = 1e-4
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1
src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
div_val[div_val < 1e-4] = 1e-4
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
src_grid[:, 0] = src_grid[:, 0] / ((length - 1) / 2) - 1 # scale to -1~1
src_grid[:, 1] = src_grid[:, 1] / ((length - 1) / 2) - 1 # scale to -1~1
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
return src_grid
@@ -29,38 +29,55 @@ def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
@return:
"""
rfn = Ks.shape[0]
scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device)
scale_m = torch.tensor([x_ratio, y_ratio, 1.0],
dtype=torch.float32,
device=Ks.device)
scale_m = torch.diag(scale_m)
ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4
pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device)
pad_vals = torch.zeros([rfn, 1, 4],
dtype=torch.float32,
device=ref_prj.device)
pad_vals[:, :, 3] = 1.0
ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4
return ref_prj
def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
B, _, D, H, W = volume_xyz.shape
ratio = warp_size / input_size
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2)
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
warp_coords = project_and_normalize(
volume_xyz.view(B, 3, D * H * W), warp_proj,
warp_size).view(B, D, H, W, 2)
return warp_coords
def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
def create_target_volume(depth_size,
volume_size,
input_image_size,
pose_target,
K,
near=None,
far=None):
device, dtype = pose_target.device, pose_target.dtype
# compute a depth range on the unit sphere
H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
if near is not None and far is not None :
if near is not None and far is not None:
# near, far b,1,h,w
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
depth_values = depth_values * (far - near) + near # b d h w
depth_values = torch.linspace(
0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
depth_values = depth_values * (far - near) + near # b d h w
depth_values = depth_values.view(B, 1, D, H * W)
else:
near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W)
near, far = near_far_from_unit_sphere_using_camera_poses(
pose_target) # b 1
depth_values = torch.linspace(
0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values[None, :, None] * (
far[:, None, :] - near[:, None, :]) + near[:, None, :] # b d 1
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H * W)
ratio = volume_size / input_image_size
@@ -68,20 +85,28 @@ def create_target_volume(depth_size, volume_size, input_image_size, pose_target,
# H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0]
# creat mesh grid: note reference also means target
ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2)
ref_grid = create_meshgrid(
H, W, normalized_coordinates=False) # (1, H, W, 2)
ref_grid = ref_grid.to(device).to(dtype)
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W)
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, H*W)
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
ref_grid = ref_grid.reshape(1, 2, H * W) # (1, 2, H*W)
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
ref_grid = torch.cat(
(ref_grid,
torch.ones(B, 1, H * W, dtype=ref_grid.dtype,
device=ref_grid.device)),
dim=1) # (B, 3, H*W)
ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W)
# unproject to space and transfer to world coordinates.
Ks = K
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W)
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
ref_grid = ref_proj_inv[:, :3, :3] @ ref_grid.view(
B, 3, D * H
* W) + ref_proj_inv[:, :3, 3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
return ref_grid.reshape(B, 3, D, H, W), depth_values.view(B, 1, D, H, W)
def near_far_from_unit_sphere_using_camera_poses(camera_poses):
"""
@@ -90,14 +115,16 @@ def near_far_from_unit_sphere_using_camera_poses(camera_poses):
near: b,1
far: b,1
"""
R_w2c = camera_poses[..., :3, :3] # b 3 3
t_w2c = camera_poses[..., :3, 3:] # b 3 1
camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1
R_w2c = camera_poses[..., :3, :3] # b 3 3
t_w2c = camera_poses[..., :3, 3:] # b 3 1
camera_origin = -R_w2c.permute(0, 2, 1) @ t_w2c # b 3 1
# R_w2c.T @ (0,0,1) = z_dir
camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1
camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3
a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
mid = b / a # b 1
camera_orient = R_w2c.permute(0, 2, 1)[..., :3, 2:3] # b 3 1
camera_origin, camera_orient = camera_origin[...,
0], camera_orient[...,
0] # b 3
a = torch.sum(camera_orient**2, dim=-1, keepdim=True) # b 1
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
mid = b / a # b 1
near, far = mid - 1.0, mid + 1.0
return near, far
return near, far
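Review note: a minimal sketch of construct_project_matrix and near_far_from_unit_sphere_using_camera_poses above, with a hypothetical intrinsic matrix and pose; the ratio arguments rescale K to the feature-volume resolution and the 3x4 projection is padded to 4x4.
import torch
K = torch.tensor([[[256., 0., 128.], [0., 256., 128.], [0., 0., 1.]]])       # (1, 3, 3)
pose = torch.cat([torch.eye(3), torch.tensor([[0.], [0.], [2.]])], 1)[None]  # (1, 3, 4) w2c
prj = construct_project_matrix(0.25, 0.25, K, pose)             # (1, 4, 4) for a 1/4-res volume
near, far = near_far_from_unit_sphere_using_camera_poses(pose)  # each (1, 1)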

View File

@@ -1,11 +1,13 @@
from inspect import isfunction
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from torch import einsum, nn
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import checkpoint
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
checkpoint
def exists(val):
@@ -13,7 +15,7 @@ def exists(val):
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@@ -35,6 +37,7 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -42,8 +45,11 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# feedforward
class ConvGEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0)
@@ -54,20 +60,16 @@ class ConvGEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
@@ -83,54 +85,54 @@ def zero_module(module):
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
@@ -140,7 +142,7 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
@@ -155,16 +157,22 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@@ -172,9 +180,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
h = self.heads
@@ -184,12 +190,13 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = mask>0
mask = mask > 0
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
@@ -202,8 +209,15 @@ class CrossAttention(nn.Module):
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class BasicSpatialTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
def __init__(self,
dim,
n_heads,
d_head,
context_dim=None,
checkpoint=True):
super().__init__()
inner_dim = n_heads * d_head
self.proj_in = nn.Sequential(
@@ -212,7 +226,12 @@ class BasicSpatialTransformer(nn.Module):
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
)
self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn
self.attn = CrossAttention(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim
) # is a self-attention if not self.disable_self_attn
self.out_conv = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
@@ -221,16 +240,18 @@ class BasicSpatialTransformer(nn.Module):
self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
zero_module(
nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context):
# input
b,_,h,w = x.shape
b, _, h, w = x.shape
x_in = x
x = self.proj_in(x)
@@ -245,44 +266,64 @@ class BasicSpatialTransformer(nn.Module):
x = self.proj_out(x) + x_in
return x
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else
None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class ConvFeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1, 0),
nn.GELU()
) if not glu else ConvGEGLU(dim, inner_dim)
nn.GELU()) if not glu else ConvGEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim_out, 1, 1, 0)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim_out, 1, 1, 0))
def forward(self, x):
return self.net(x)
@@ -296,31 +337,36 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
disable_self_attn=disable_self_attn)
for d in range(depth)]
)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn) for d in range(depth)
])
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
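Review note: a shape sketch for the CrossAttention block defined above, with hypothetical sizes; when constructed with context_dim=None and called with context=None, it degenerates to self-attention via default().
import torch
attn = CrossAttention(query_dim=64, context_dim=128, heads=4, dim_head=16)
x = torch.randn(2, 100, 64)           # b, query tokens, query_dim
ctx = torch.randn(2, 50, 128)         # b, context tokens, context_dim
out = attn(x, ctx)                    # (2, 100, 64)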

View File

@@ -1,6 +1,6 @@
import math
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
@@ -8,16 +8,11 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.image_to_3d.ldm.modules.attention import \
SpatialTransformer
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer
avg_pool_nd, checkpoint, conv_nd, linear, normalization,
timestep_embedding, zero_module)
from modelscope.models.cv.image_to_3d.ldm.util import exists
@@ -25,11 +20,11 @@ from modelscope.models.cv.image_to_3d.ldm.util import exists
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -43,7 +38,8 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@@ -98,37 +94,46 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode='nearest')
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
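A quick shape check for Upsample (values are illustrative): nearest-neighbour 2x interpolation on the spatial axes, followed by the optional 3x3 convolution; for dims=3 the depth axis is left untouched.
import torch
up = Upsample(channels=64, use_conv=True)
y = up(torch.randn(1, 64, 32, 32))        # -> [1, 64, 64, 64]
up3d = Upsample(channels=64, use_conv=False, dims=3)
z = up3d(torch.randn(1, 64, 8, 32, 32))   # -> [1, 64, 8, 64, 64], depth kept at 8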
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2)
def forward(self,x):
def forward(self, x):
return self.up(x)
@@ -141,7 +146,12 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -150,8 +160,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -220,7 +234,8 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
@@ -228,18 +243,18 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(dims, channels, self.out_channels,
1)
def forward(self, x, emb):
"""
@@ -248,10 +263,8 @@ class ResBlock(TimestepBlock):
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
return checkpoint(self._forward, (x, emb), self.parameters(),
self.use_checkpoint)
def _forward(self, x, emb):
if self.updown:
@@ -265,7 +278,7 @@ class ResBlock(TimestepBlock):
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm: # False
if self.use_scale_shift_norm: # False
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
@@ -298,7 +311,7 @@ class AttentionBlock(nn.Module):
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
@@ -313,8 +326,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
return checkpoint(
self._forward, (x, ), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@@ -341,7 +356,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
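As a concrete check of that accounting (numbers are only an example): for one image with a 16x16 attention map and 512 channels,
b, c = 1, 512
num_spatial = 16 * 16
matmul_ops = 2 * b * num_spatial**2 * c   # = 67,108,864, roughly 6.7e7 MACs for the layer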
@@ -363,13 +378,14 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
'bct,bcs->bts', q * scale,
k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
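The fourth-root scale is the standard 1/sqrt(ch) attention scaling split evenly between q and k, which keeps the intermediate products smaller under fp16: (q * ch**-0.25) against (k * ch**-0.25) equals (q against k) / sqrt(ch). A tiny numeric check with arbitrary shapes:
import math
import torch
q = torch.randn(2, 64, 10)   # [batch * heads, ch, tokens]
k = torch.randn(2, 64, 10)
ch = q.shape[1]
s = 1 / math.sqrt(math.sqrt(ch))
w_split = torch.einsum('bct,bcs->bts', q * s, k * s)         # scale applied to q and k separately
w_plain = torch.einsum('bct,bcs->bts', q, k) / math.sqrt(ch)
assert torch.allclose(w_split, w_plain, atol=1e-5)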
@staticmethod
@@ -398,12 +414,13 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = th.einsum('bts,bcs->bct', weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
@@ -442,40 +459,43 @@ class UNetModel(nn.Module):
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None
):
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert context_dim is not None, (
'Fool!! You forgot to include the dimension '
'of your cross-attention conditioning...')
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert use_spatial_transformer, (
'Fool!! You forgot to use the spatial transformer '
'for your cross-attention conditioning...')
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
@@ -497,20 +517,28 @@ class UNetModel(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
raise ValueError(
'provide num_res_blocks either as an int (globally constant) or '
'as a list/tuple (per-level) with the same length as channel_mult'
)
self.num_res_blocks = num_res_blocks
#self.num_res_blocks = num_res_blocks
# self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.") # todo: convert to warning
assert all(
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
],
range(len(num_attention_blocks))))
print(
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
f'attention will still not be set.'
) # todo: convert to warning
self.attention_resolutions = attention_resolutions
self.dropout = dropout
@@ -534,13 +562,10 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
) # 0
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
]) # 0
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
@@ -559,21 +584,22 @@ class UNetModel(nn.Module):
)
]
ch = mult * model_channels
if ds in attention_resolutions: # always True
if ds in attention_resolutions: # always True
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
if not exists(num_attention_blocks
) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
@@ -581,11 +607,14 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
) if not use_spatial_transformer else
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
@@ -602,12 +631,8 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
) if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
@@ -620,7 +645,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -637,9 +662,13 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
) if not use_spatial_transformer else
SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,
@@ -674,14 +703,15 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
if not exists(num_attention_blocks
) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
@@ -689,11 +719,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
) if not use_spatial_transformer else
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa))
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
@@ -706,10 +739,8 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
) if resblock_updown else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@@ -717,14 +748,15 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
@@ -742,7 +774,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -753,18 +785,19 @@ class UNetModel(nn.Module):
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
), 'must specify y if and only if the model is class-conditional'
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N
emb = self.time_embed(t_emb) #
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False) # N
emb = self.time_embed(t_emb) #
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
assert y.shape == (x.shape[0], )
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context) # conv
h = module(h, emb, context) # conv
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
@@ -783,30 +816,28 @@ class EncoderUNetModel(nn.Module):
For usage, see UNet.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs
):
def __init__(self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool='adaptive',
*args,
**kwargs):
super().__init__()
if num_heads_upsample == -1:
@@ -833,13 +864,10 @@ class EncoderUNetModel(nn.Module):
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
])
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
@@ -866,8 +894,7 @@ class EncoderUNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
@@ -884,12 +911,8 @@ class EncoderUNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
) if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
@@ -923,7 +946,7 @@ class EncoderUNetModel(nn.Module):
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
if pool == 'adaptive':
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
@@ -931,22 +954,21 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
elif pool == 'attention':
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
AttentionPool2d((image_size // ds), ch, num_head_channels,
out_channels),
)
elif pool == "spatial":
elif pool == 'spatial':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
elif pool == 'spatial_v2':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
@@ -954,7 +976,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self):
"""
@@ -977,20 +999,20 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)

View File

@@ -7,50 +7,65 @@
#
# thanks!
import os
import math
import os
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == 'linear':
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == "cosine":
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == 'sqrt':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
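For reference, a hedged sketch of how such a schedule is usually consumed downstream (the step count and variable names are illustrative, not part of this commit): the betas define alphas = 1 - betas, whose cumulative product is the alpha_bar_t used in q(x_t | x_0).
import numpy as np
betas = make_beta_schedule('linear', n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
print(betas.shape, float(alphas_cumprod[-1]))   # (1000,) and a value close to zero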
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
@@ -60,17 +75,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out
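A worked example of the 'uniform' branch (numbers are illustrative): with num_ddpm_timesteps=1000 and num_ddim_timesteps=50, c = 20 and the raw grid is [0, 20, 40, ..., 980]; after the +1 correction described in the comment above, the returned steps are [1, 21, 41, ..., 981].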
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray([alphacums[0]]
+ alphacums[ddim_timesteps[:-1]].tolist())
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
# rewrite because of E125
tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
sigmas = (eta * np.sqrt(tmp))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
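This is the sigma from the DDIM paper (arXiv 2010.02502), written here with the cumulative alphas selected at the DDIM timesteps: sigma_t = eta * sqrt((1 - alpha_prev) / (1 - alpha)) * sqrt(1 - alpha / alpha_prev). Setting eta = 0 gives the deterministic DDIM sampler, while eta = 1 recovers DDPM-like stochastic sampling.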
@@ -96,7 +121,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
@@ -117,6 +142,7 @@ def checkpoint(func, inputs, params, flag):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
@@ -129,7 +155,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
@@ -160,12 +188,14 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
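A short usage sketch of the sinusoidal embedding above (batch values and dim are arbitrary):
import torch
t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=128)                         # [3, 128], cos half then sin half
emb_flat = timestep_embedding(t, dim=128, repeat_only=True)  # just t broadcast to [3, 128]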
@@ -207,14 +237,17 @@ def normalization(channels):
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@@ -225,7 +258,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
@@ -245,7 +278,7 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
@@ -253,7 +286,8 @@ class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
@@ -262,6 +296,13 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def repeat_noise():
return torch.randn((1, *shape[1:]),
device=device).repeat(shape[0],
*((1, ) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

View File

@@ -1,8 +1,9 @@
import torch
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
@@ -11,6 +12,7 @@ class AbstractDistribution:
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
@@ -22,6 +24,7 @@ class DiracDistribution(AbstractDistribution):
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
@@ -30,10 +33,12 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(
self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
@@ -41,21 +46,22 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
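Both branches are the closed-form KL between diagonal Gaussians, summed over the non-batch dims: KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) = 0.5 * (log(sigma2^2 / sigma1^2) + (sigma1^2 + (mu1 - mu2)^2) / sigma2^2 - 1). With mu2 = 0 and sigma2 = 1 this reduces to 0.5 * (mu1^2 + sigma1^2 - 1 - log sigma1^2), which is exactly the other-is-None branch above.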
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
logtwopi + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
@@ -64,7 +70,8 @@ class DiagonalGaussianDistribution(object):
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
source: https://github.com/openai/guided-diffusion/blob/
27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
@@ -74,7 +81,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
@@ -83,10 +90,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
# rewrite because of W504
tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + tmp
) # noqa

View File

@@ -2,15 +2,17 @@ import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
from typing import Any, List, Union
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from pkg_resources import packaging
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
ToTensor)
from tqdm import tqdm
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import \
build_model
try:
from torchvision.transforms import InterpolationMode
@@ -18,23 +20,40 @@ try:
except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(
torch.__version__) < packaging.version.parse('1.7.1'):
warnings.warn('PyTorch version 1.7.1 or higher is recommended')
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load"]
__all__ = ['available_models', 'load']
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
'RN50':
'https://openaipublic.azureedge.net/clip/models/'
'afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
'RN101':
'https://openaipublic.azureedge.net/clip/models/'
'8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
'RN50x4':
'https://openaipublic.azureedge.net/clip/models/'
'7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
'RN50x16':
'https://openaipublic.azureedge.net/clip/models/'
'52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
'RN50x64':
'https://openaipublic.azureedge.net/clip/models/'
'be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
'ViT-B/32':
'https://openaipublic.azureedge.net/clip/models/'
'40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
'ViT-B/16':
'https://openaipublic.azureedge.net/clip/models/'
'5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
'ViT-L/14':
'https://openaipublic.azureedge.net/clip/models/'
'b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
'ViT-L/14@336px':
'https://openaipublic.azureedge.net/clip/models/'
'3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
}
@@ -42,20 +61,30 @@ def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
expected_sha256 = url.split('/')[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
raise RuntimeError(
f'{download_target} exists and is not a regular file')
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
if hashlib.sha256(open(download_target,
'rb').read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with urllib.request.urlopen(url) as source, open(download_target,
'wb') as output:
with tqdm(
total=int(source.info().get('Content-Length')),
ncols=80,
unit='iB',
unit_scale=True,
unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@@ -64,14 +93,17 @@ def _download(url: str, root: str):
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
if hashlib.sha256(open(download_target,
'rb').read()).hexdigest() != expected_sha256:
raise RuntimeError(
'Model has been downloaded but the SHA256 checksum does not match'
)
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
return image.convert('RGB')
def _transform(n_px):
@@ -80,7 +112,8 @@ def _transform(n_px):
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
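A hedged usage sketch of this preprocessing pipeline (the image path is a placeholder; 224 is the input size of the ViT-B checkpoints):
from PIL import Image
preprocess = _transform(224)
x = preprocess(Image.open('example.jpg'))   # placeholder path -> tensor of shape [3, 224, 224]
batch = x.unsqueeze(0)                      # [1, 3, 224, 224], ready for model.encode_image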
@@ -89,7 +122,11 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
def load(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit: bool = False,
download_root: str = None):
"""Load a CLIP model
Parameters
@@ -115,37 +152,47 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
model_path = _download(
_MODELS[name], download_root
or os.path.expanduser('~/.cache/clip'))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}')
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
model = torch.jit.load(
opened_file, map_location=device if jit else 'cpu').eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
warnings.warn(
f'File {model_path} is not a JIT archive. Loading as a state dict instead'
)
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
state_dict = torch.load(opened_file, map_location='cpu')
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
if str(device) == 'cpu':
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]
def _node_get(node: torch._C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type.
From https://github.com/pytorch/pytorch/pull/82628
"""
sel = node.kindOf(key)
@@ -153,16 +200,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
for node in graph.findAllNodes('prim::Constant'):
if 'value' in node.attributeNames() and str(
_node_get(node, 'value')).startswith('cuda'):
node.copyAttributes(device_node)
model.apply(patch_device)
@@ -170,25 +218,28 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
if str(device) == 'cpu':
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if _node_get(inputs[i].node(), "value") == 5:
for i in [
1, 2
]: # dtype can be the second or third argument to aten::to()
if _node_get(inputs[i].node(), 'value') == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)

View File

@@ -33,11 +33,16 @@ class Bottleneck(nn.Module):
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
identity = x
@@ -56,9 +61,15 @@ class Bottleneck(nn.Module):
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -70,14 +81,17 @@ class AttentionPool2d(nn.Module):
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
@@ -86,8 +100,7 @@ class AttentionPool2d(nn.Module):
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
need_weights=False)
return x.squeeze(0)
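A shape-level sanity check for the pooling above, using the sizes a ModifiedResNet-50 backbone would typically feed it (quoted as an assumption): the mean token is prepended, attention runs with that token as the only query, and one pooled vector per image comes back.
import torch
pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
feats = torch.randn(2, 2048, 7, 7)   # NCHW features after the final ResNet stage
pooled = pool(feats)                 # -> [2, 1024]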
@@ -99,19 +112,27 @@ class ModifiedResNet(nn.Module):
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
@@ -124,7 +145,8 @@ class ModifiedResNet(nn.Module):
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
@@ -136,6 +158,7 @@ class ModifiedResNet(nn.Module):
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
@@ -164,27 +187,34 @@ class LayerNorm(nn.LayerNorm):
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
@@ -193,26 +223,42 @@ class ResidualAttentionBlock(nn.Module):
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width ** -0.5
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
@@ -222,9 +268,15 @@ class VisionTransformer(nn.Module):
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
# rewrite because of E126
tmp = self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)  # noqa
x = torch.cat([tmp, x], dim=1)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
@@ -241,20 +293,21 @@ class VisionTransformer(nn.Module):
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()
self.context_length = context_length
@@ -266,8 +319,7 @@ class CLIP(nn.Module):
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
@@ -276,22 +328,22 @@ class CLIP(nn.Module):
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
output_dim=embed_dim)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
@@ -302,20 +354,24 @@ class CLIP(nn.Module):
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
if name.endswith('bn3.weight'):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
@@ -323,13 +379,14 @@ class CLIP(nn.Module):
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
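For a toy context_length of 4 the resulting additive mask is zero on and below the diagonal and -inf above it, so token i can only attend to tokens j <= i:
[[0, -inf, -inf, -inf],
 [0,    0, -inf, -inf],
 [0,    0,    0, -inf],
 [0,    0,    0,    0]]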
@@ -341,7 +398,8 @@ class CLIP(nn.Module):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
@@ -351,7 +409,8 @@ class CLIP(nn.Module):
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return x
@@ -360,7 +419,8 @@ class CLIP(nn.Module):
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
image_features = image_features / image_features.norm(
dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
@@ -375,21 +435,24 @@ class CLIP(nn.Module):
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
def _convert_weights_to_fp16(_l):
if isinstance(_l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
_l.weight.data = _l.weight.data.half()
if _l.bias is not None:
_l.bias.data = _l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if isinstance(_l, nn.MultiheadAttention):
for attr in [
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(_l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
for name in ['text_projection', 'proj']:
if hasattr(_l, name):
attr = getattr(_l, name)
if attr is not None:
attr.data = attr.data.half()
@@ -397,37 +460,51 @@ def convert_weights(model: nn.Module):
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
vit = 'visual.proj' in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
vision_width = state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([
k for k in state_dict.keys()
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
grid_size = round(
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
counts: list = [
len(
set(
k.split('.')[2] for k in state_dict
if k.startswith(f'visual.layer{b}')))
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
output_width = round(
(state_dict['visual.attnpool.positional_embedding'].shape[0]
- 1)**0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
assert output_width**2 + 1 == state_dict[
'visual.attnpool.positional_embedding'].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
embed_dim = state_dict['text_projection'].shape[1]
context_length = state_dict['positional_embedding'].shape[0]
vocab_size = state_dict['token_embedding.weight'].shape[0]
transformer_width = state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
transformer_layers = len(
set(
k.split('.')[2] for k in state_dict
if k.startswith('transformer.resblocks')))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
vision_patch_size, context_length, vocab_size,
transformer_width, transformer_heads, transformer_layers)
for key in ["input_resolution", "context_length", "vocab_size"]:
for key in ['input_resolution', 'context_length', 'vocab_size']:
if key in state_dict:
del state_dict[key]
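A toy illustration (hypothetical key names) of how the ViT depth is recovered purely from checkpoint key names, mirroring the comprehension above.

fake_keys = [
    'visual.conv1.weight',
    'visual.transformer.resblocks.0.attn.in_proj_weight',
    'visual.transformer.resblocks.1.attn.in_proj_weight',
    'visual.transformer.resblocks.2.attn.in_proj_weight',
]
vision_layers = len([
    k for k in fake_keys
    if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
print(vision_layers)  # 3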

View File

@@ -9,7 +9,9 @@ import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_simple_vocab_16e6.txt.gz')
@lru_cache()
@@ -23,13 +25,17 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
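A quick check of the property described above: the table is a bijection over all 256 byte values, and printable ASCII maps to itself, so plain words pass through the byte encoder unchanged.

b2u = bytes_to_unicode()
assert len(b2u) == 256 and len(set(b2u.values())) == 256
print(''.join(b2u[b] for b in 'hello'.encode('utf-8')))  # hello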
@@ -60,34 +66,41 @@ def whitespace_clean(text):
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
return token + '</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
@@ -98,12 +111,13 @@ class SimpleTokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
@@ -122,11 +136,14 @@ class SimpleTokenizer(object):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
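A self-contained trace of one iteration of the merge loop in bpe() above, with a made-up word and bigram: the chosen pair ('l', 'l') is fused into a single symbol wherever it occurs.

word = ('h', 'e', 'l', 'l', 'o</w>')
first, second = 'l', 'l'
new_word = []
i = 0
while i < len(word):
    try:
        j = word.index(first, i)
        new_word.extend(word[i:j])
        i = j
    except ValueError:
        new_word.extend(word[i:])
        break
    if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
        new_word.append(first + second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1
print(tuple(new_word))  # ('h', 'e', 'll', 'o</w>')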

View File

@@ -1,28 +1,45 @@
import random
from functools import partial
import kornia
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import kornia
import torch.nn.functional as F
from torchvision import transforms
from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel,
T5EncoderModel, T5Tokenizer)
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from modelscope.models.cv.image_to_3d.ldm.util import default
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
extract_into_tensor, make_beta_schedule, noise_like)
# import clip
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip
# TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import (
Encoder, TransformerWrapper)
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
from modelscope.models.cv.image_to_3d.ldm.util import (default,
instantiate_from_config)
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
@@ -35,16 +52,16 @@ class FaceClipEncoder(AbstractEncoder):
x_offset = 125
if self.retreival_key:
# Assumes the retrieved image is packed into the second half of the channels
face = img[:,3:,190:440,x_offset:(512-x_offset)]
other = img[:,:3,...].clone()
face = img[:, 3:, 190:440, x_offset:(512 - x_offset)]
other = img[:, :3, ...].clone()
else:
face = img[:,:,190:440,x_offset:(512-x_offset)]
face = img[:, :, 190:440, x_offset:(512 - x_offset)]
other = img.clone()
if self.augment:
face = K.RandomHorizontalFlip()(face)
other[:,:,190:440,x_offset:(512-x_offset)] *= 0
other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0
encodings = [
self.encoder.encode(face),
self.encoder.encode(other),
@@ -55,26 +72,32 @@ class FaceClipEncoder(AbstractEncoder):
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class FaceIdClipEncoder(AbstractEncoder):
def __init__(self):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters():
p.requires_grad = False
self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)
self.id = FrozenFaceEncoder(
'/home/jpinkney/code/stable-diffusion/model_ir_se50.pth',
augment=True)
def forward(self, img):
encodings = []
with torch.no_grad():
face = kornia.geometry.resize(img, (256, 256),
interpolation='bilinear', align_corners=True)
face = kornia.geometry.resize(
img, (256, 256), interpolation='bilinear', align_corners=True)
other = img.clone()
other[:,:,184:452,122:396] *= 0
other[:, :, 184:452, 122:396] *= 0
encodings = [
self.id.encode(face),
self.encoder.encode(other),
@@ -85,11 +108,15 @@ class FaceIdClipEncoder(AbstractEncoder):
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
self.key = key
@@ -106,11 +133,19 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(self,
n_embed,
n_layer,
vocab_size,
max_seq_len=77,
device='cuda'):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@@ -123,18 +158,25 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
return tokens
@torch.no_grad()
@@ -150,20 +192,30 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
def __init__(self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
@@ -174,8 +226,6 @@ class BERTEmbedder(AbstractEncoder):
return self(text)
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
@@ -184,24 +234,41 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(self,
version='google/t5-v1_1-large',
device='cuda',
max_length=77
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.tokenizer = T5Tokenizer.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.transformer = T5EncoderModel.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.device = device
self.max_length = max_length # TODO: typical value?
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
@@ -210,10 +277,9 @@ class FrozenT5Embedder(AbstractEncoder):
def encode(self, text):
return self(text)
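A minimal sketch of the freeze() pattern shared by the frozen embedders above, using a stand-in module: eval mode plus requires_grad=False keeps the encoder fixed during training.

import torch.nn as nn

encoder = nn.Linear(8, 8).eval()        # stand-in for the HF transformer
for param in encoder.parameters():
    param.requires_grad = False
assert not any(p.requires_grad for p in encoder.parameters())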
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
import kornia.augmentation as K
class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False):
super().__init__()
self.loss_fn = IDFeatures(model_path)
@@ -242,8 +308,8 @@ class FrozenFaceEncoder(AbstractEncoder):
if self.augment is not None:
# Transforms require 0-1
img = self.augment((img + 1)/2)
img = 2*img - 1
img = self.augment((img + 1) / 2)
img = 2 * img - 1
feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1))
@@ -252,26 +318,43 @@ class FrozenFaceEncoder(AbstractEncoder):
def encode(self, img):
return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.tokenizer = CLIPTokenizer.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.transformer = CLIPTextModel.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.device = device
self.max_length = max_length # TODO: typical value?
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
@@ -280,36 +363,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text):
return self(text)
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.model = CLIPVisionModel.from_pretrained(version)
self.model.train()
self.max_length = max_length # TODO: typical value?
self.max_length = max_length # TODO: typical value?
self.antialias = True
self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
null_cond = self.get_null_cond(version, max_length)
self.register_buffer('null_cond', null_cond)
@torch.no_grad()
def get_null_cond(self, version, max_length):
device = self.mean.device
embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
null_cond = embedder([""])
embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length)
null_cond = embedder([''])
return null_cond
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = kornia.geometry.resize(
x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
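A numeric sketch of the scaling done in preprocess, using plain torch in place of the kornia call: inputs in [-1, 1] are shifted to [0, 1] and then standardized with the CLIP statistics.

import torch

mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
x = torch.zeros(1, 3, 224, 224)   # mid-gray in the [-1, 1] convention
x = (x + 1.) / 2.                 # -> 0.5 everywhere
x = (x - mean) / std              # same arithmetic as kornia.enhance.normalize
print(x[0, :, 0, 0])              # per-channel standardized values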
@@ -323,15 +417,23 @@ class ClipImageProjector(AbstractEncoder):
outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state)
return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])
return F.pad(
last_hidden_state,
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0])
def encode(self, im):
return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
self.embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length)
self.projection = torch.nn.Linear(768, 768)
def forward(self, text):
@@ -341,31 +443,41 @@ class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text):
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that, set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=False,
):
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = kornia.geometry.resize(
x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
@@ -382,35 +494,41 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
def encode(self, im):
return self(im).unsqueeze(1)
from torchvision import transforms
import random
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that, set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=True,
max_crops=5,
):
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=True,
max_crops=5,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
self.max_crops = max_crops
def preprocess(self, x):
# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
randcrop = transforms.RandomResizedCrop(
224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
@@ -441,7 +559,9 @@ class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
def encode(self, im):
return self(im)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
@@ -452,19 +572,24 @@ class SpatialRescaler(nn.Module):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias)
def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
@@ -473,25 +598,38 @@ class SpatialRescaler(nn.Module):
return self(x)
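A hypothetical usage sketch, assuming the remaining constructor defaults that sit outside this hunk (in particular out_channels=None, so no channel remap): one bilinear stage with multiplier=0.5 halves the spatial resolution.

import torch

rescaler = SpatialRescaler(n_stages=1, method='bilinear', multiplier=0.5)
x = torch.randn(2, 3, 64, 64)
print(rescaler(x).shape)  # torch.Size([2, 3, 32, 32])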
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
def __init__(self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
linear_end=linear_end)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
def register_schedule(self,
beta_schedule='linear',
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
@@ -500,33 +638,45 @@ class LowScaleEncoder(nn.Module):
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
assert alphas_cumprod.shape[
0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
self.register_buffer('alphas_cumprod_prev',
to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
* x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
x_start.shape) * noise)
def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0], ), device=x.device).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
z = torch.nn.functional.interpolate(
z, size=self.out_size,
mode='nearest') # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
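A scalar sketch of the closed-form forward-diffusion step that q_sample implements, x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, using a simple linear beta schedule purely for illustration (the real betas come from make_beta_schedule).

import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)   # illustrative schedule only
alphas_cumprod = np.cumprod(1. - betas)
t = 250                                 # a noise level below max_noise_level
x0, eps = 1.0, 0.1                      # scalar stand-ins for latent and noise
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1. - alphas_cumprod[t]) * eps
print(x_t)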
@@ -535,10 +685,13 @@ class LowScaleEncoder(nn.Module):
return self.model.decode(z)
if __name__ == "__main__":
if __name__ == '__main__':
from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
sentences = [
'a hedgehog drinking a whiskey', 'der mond ist aufgegangen',
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"
]
model = FrozenT5Embedder(version='google/t5-v1_1-xl').cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
@@ -548,4 +701,4 @@ if __name__ == "__main__":
z = model(sentences)
print(z.shape)
print("done.")
print('done.')

View File

@@ -1,28 +1,26 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch
from torch import nn, einsum
import torch.nn.functional as F
from collections import namedtuple
from functools import partial
from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn
# constants
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple('Intermediates',
['pre_softmax_attn', 'post_softmax_attn'])
LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates'
])
LayerIntermediates = namedtuple('Intermediates',
['hiddens', 'attn_intermediates'])
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
@@ -37,13 +35,15 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
t = torch.arange(
x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
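A shape walk-through of the sinusoidal table built above: for dim=8 and a 5-token sequence, the sin and cos halves concatenate into a (1, 5, 8) embedding.

import torch

dim, seq_len = 8, 5
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len).float()
sinusoid_inp = torch.einsum('i , j -> i j', t, inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
print(emb[None, :, :].shape)  # torch.Size([1, 5, 8])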
@@ -51,6 +51,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers
def exists(val):
return val is not None
@@ -62,20 +63,26 @@ def default(val, d):
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
@@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@@ -96,7 +104,7 @@ def group_dict_by_key(cond, d):
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
return (*return_val, )
def string_begins_with(prefix, str):
@@ -108,13 +116,17 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix):], x[1]),
tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
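An illustrative call (made-up kwargs) showing how the helpers above split prefixed keyword arguments and strip the prefix; this is how AttentionLayers separates 'ff_' and 'attn_' options later in the file.

kwargs = {'attn_dropout': 0.1, 'ff_mult': 4, 'depth': 6}
attn_kwargs, rest = groupby_prefix_and_trim('attn_', kwargs)
print(attn_kwargs)  # {'dropout': 0.1}
print(rest)         # {'ff_mult': 4, 'depth': 6}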
# classes
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
@@ -126,6 +138,7 @@ class Scale(nn.Module):
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
@@ -137,9 +150,10 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -149,9 +163,10 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
@@ -161,11 +176,13 @@ class RMSNorm(nn.Module):
class Residual(nn.Module):
def forward(self, x, residual):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
@@ -173,15 +190,16 @@ class GRUGating(nn.Module):
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
rearrange(residual, 'b n d -> (b n) d'))
return gated_output.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -192,20 +210,16 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
@@ -213,24 +227,24 @@ class FeedForward(nn.Module):
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
):
def __init__(self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!')
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
@@ -252,7 +266,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
@@ -263,19 +277,19 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
self.to_out = nn.Sequential(nn.Linear(
inner_dim, dim
* 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
):
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)
@@ -297,23 +311,29 @@ class Attention(nn.Module):
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
q_mask = default(mask, lambda: torch.ones(
(b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
k_mask = default(
k_mask, lambda: torch.ones(
(b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
@@ -324,7 +344,8 @@ class Attention(nn.Module):
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum('b h i j, h k -> b k i j', dots,
self.pre_softmax_proj).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
@@ -336,7 +357,8 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@@ -354,59 +376,60 @@ class Attention(nn.Module):
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum('b h i j, h k -> b k i j', attn,
self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
post_softmax_attn=post_softmax_attn)
return self.to_out(out), intermediates
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
):
def __init__(self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
# dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = FixedPositionalEmbedding(
dim) if position_infused_attn else None
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
assert rel_pos_num_buckets <= rel_pos_max_distance, \
'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
@@ -429,7 +452,7 @@ class AttentionLayers(nn.Module):
default_block = ('a', 'f')
if macaron:
default_block = ('f',) + default_block
default_block = ('f', ) + default_block
if exists(custom_layers):
layer_types = custom_layers
@@ -440,13 +463,17 @@ class AttentionLayers(nn.Module):
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert len(
default_block
) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f', ) * (
par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
layer_types = par_head + ('f', ) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
layer_types = ('a', ) * sandwich_coef + default_block * (
depth - sandwich_coef) + ('f', ) * sandwich_coef
else:
layer_types = default_block * depth
@@ -455,7 +482,8 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
@@ -472,21 +500,15 @@ class AttentionLayers(nn.Module):
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False
):
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False):
hiddens = []
intermediates = []
prev_attn = None
@@ -494,7 +516,8 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
@@ -507,10 +530,20 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn)
elif layer_type == 'f':
out = block(x)
@@ -529,9 +562,7 @@ class AttentionLayers(nn.Module):
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
)
hiddens=hiddens, attn_intermediates=intermediates)
return x, intermediates
@@ -539,28 +570,29 @@ class AttentionLayers(nn.Module):
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
):
def __init__(self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@@ -571,22 +603,26 @@ class TransformerWrapper(nn.Module):
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = nn.Linear(emb_dim,
dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = nn.Linear(
dim, num_tokens
) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim))
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
@@ -595,17 +631,16 @@ class TransformerWrapper(nn.Module):
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
def forward(self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs):
# b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
b, _, num_mem = *x.shape, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
@@ -620,7 +655,8 @@ class TransformerWrapper(nn.Module):
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
@@ -629,13 +665,18 @@ class TransformerWrapper(nn.Module):
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = list(
map(lambda pair: torch.cat(pair, dim=-2), zip(
mems, hiddens))) if exists(mems) else hiddens
new_mems = list(
map(lambda t: t[..., -self.max_mem_len:, :].detach(),
new_mems))
return out, new_mems
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = list(
map(lambda t: t.post_softmax_attn,
intermediates.attn_intermediates))
return out, attn_maps
return out

View File

@@ -1,121 +1,134 @@
# https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
from collections import namedtuple
import torch
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d,
Module, PReLU, ReLU, Sequential, Sigmoid)
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
""" A named tuple describing a ResNet block. """
""" A named tuple describing a ResNet block. """
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
return [Bottleneck(in_channel, depth, stride)
] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
return blocks
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError(
'Invalid number of layers: {}. Must be one of [50, 100, 152]'.
format(num_layers))
return blocks
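A small illustration of what one stage spec expands to: the first unit carries the stride, the remaining units keep stride 1 (the repr below uses the underlying 'Block' namedtuple name).

blocks = get_block(in_channel=64, depth=128, num_units=3)
print(blocks)
# [Block(in_channel=64, depth=128, stride=2),
#  Block(in_channel=128, depth=128, stride=1),
#  Block(in_channel=128, depth=128, stride=1)]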
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16)
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth), SEModule(depth, 16))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut

View File

@@ -1,22 +1,26 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone
class IDFeatures(nn.Module):
def __init__(self, model_path):
super(IDFeatures, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
self.facenet = Backbone(
input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(
torch.load(model_path, map_location='cpu'))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def forward(self, x, crop=False):
# Not sure of the image range here
if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
x = torch.nn.functional.interpolate(x, (256, 256), mode='area')
x = x[:, :, 35:223, 32:220]
x = self.face_pool(x)
x_feats = self.facenet(x)

View File

@@ -1,86 +1,97 @@
# https://github.com/eladrich/pixel2style2pixel
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
Module, PReLU, Sequential)
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import (
Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, l2_norm)
class Backbone(Module):
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], "input_size should be 112 or 224"
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 7 * 7, 512),
BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 14 * 14, 512),
BatchNorm1d(512, affine=affine))
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
modules = []
for block in blocks:
for bottleneck in block:
modules.append(unit_module(bottleneck.in_channel,
bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
def __init__(self,
input_size,
num_layers,
mode='ir',
drop_ratio=0.4,
affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], 'input_size should be 112 or 224'
assert num_layers in [50, 100,
152], 'num_layers should be 50, 100 or 152'
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine))
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir-50 model."""
model = Backbone(
input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir-101 model."""
model = Backbone(
input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir-152 model."""
model = Backbone(
input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir_se-50 model."""
model = Backbone(
input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir_se-101 model."""
model = Backbone(
input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_152(input_size):
"""Constructs a ir_se-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
return model
"""Constructs a ir_se-152 model."""
model = Backbone(
input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
return model
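
A minimal sketch exercising the reformatted constructors above (weights are randomly initialized; the eval/no-grad wrapping is simply an assumption about typical inference use):

# Hypothetical smoke test for the ir_se-50 backbone defined in this file.
import torch

from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import \
    IR_SE_50

net = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = net(torch.rand(1, 3, 112, 112))  # dummy 112x112 RGB batch
print(emb.shape)        # torch.Size([1, 512])
print(emb.norm(dim=1))  # ~1.0, because the output goes through l2_norm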

View File

@@ -1,32 +1,24 @@
import importlib
import torchvision
import torch
from torch import optim
import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import time
import cv2
import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def pil_rectangle_crop(im):
width, height = im.size # Get dimensions
width, height = im.size # Get dimensions
if width <= height:
left = 0
right = width
top = (height - width)/2
bottom = (height + width)/2
top = (height - width) / 2
bottom = (height + width) / 2
else:
top = 0
bottom = height
left = (width - height) / 2
@@ -36,6 +28,7 @@ def pil_rectangle_crop(im):
im = im.crop((left, top, right, bottom))
return im
def add_margin(pil_img, color=0, size=256):
width, height = pil_img.size
result = Image.new(pil_img.mode, (size, size), color)
@@ -46,16 +39,17 @@ def add_margin(pil_img, color=0, size=256):
def create_carvekit_interface():
from carvekit.api.high import HiInterface
# Check doc strings for more information
interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=False)
interface = HiInterface(
object_type='object', # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=False)
return interface
@@ -72,17 +66,17 @@ def load_and_preprocess(interface, input_im):
image_without_background = np.array(image_without_background)
est_seg = image_without_background > 127
image = np.array(image)
foreground = est_seg[:, : , -1].astype(np.bool_)
foreground = est_seg[:, :, -1].astype(np.bool_)
image[~foreground] = [255., 255., 255.]
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
image = image[y:y+h, x:x+w, :]
image = image[y:y + h, x:x + w, :]
image = PIL.Image.fromarray(np.array(image))
    # resize the image so that its long edge is at most 200 pixels
image.thumbnail([200, 200], Image.LANCZOS)
image = add_margin(image, (255, 255, 255), size=256)
image = np.array(image)
return image
@@ -92,16 +86,17 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
lines = '\n'.join(xc[bi][start:start + nc]
for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
            print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -117,7 +112,7 @@ def ismap(x):
def isimage(x):
if not isinstance(x,torch.Tensor):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
@@ -143,22 +138,24 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
print(
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config):
if not "target" in config:
if 'target' not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
elif config == '__is_unconditional__':
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(**config.get('params', dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
module, cls = string.rsplit('.', 1)
print(module)
if reload:
module_imp = importlib.import_module(module)
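
As a quick reference for instantiate_from_config above, a minimal sketch with a hypothetical config dict (real configs in this commit come from OmegaConf YAML files; the target and params here are illustrative only):

# Hypothetical config: 'target' is a dotted import path, 'params' are kwargs.
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config

cfg = {
    'target': 'torch.nn.Linear',
    'params': {'in_features': 8, 'out_features': 4},
}
layer = instantiate_from_config(cfg)  # -> torch.nn.Linear(8, 4)

The sentinel strings '__is_first_stage__' and '__is_unconditional__' short-circuit to None, and a config without a 'target' key raises KeyError.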
@@ -168,25 +165,42 @@ def get_obj_from_str(string, reload=False):
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
ema_power=1., param_names=()):
def __init__(
self,
params,
lr=1.e-3,
betas=(0.9, 0.999),
eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2,
amsgrad=False,
ema_decay=0.9999, # ema decay to match previous code
ema_power=1.,
param_names=()): # noqa
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError('Invalid epsilon value: {}'.format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
@@ -212,7 +226,7 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
# state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
@@ -225,7 +239,8 @@ class AdamWwithEMAandWings(optim.Optimizer):
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
raise RuntimeError(
'AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
@@ -234,12 +249,15 @@ class AdamWwithEMAandWings(optim.Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['max_exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
@@ -255,22 +273,25 @@ class AdamWwithEMAandWings(optim.Optimizer):
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
optim._functional.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
for param, ema_param in zip(params_with_grad,
ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay)
return loss
return loss
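
For context on the optimizer above, a minimal single-step sketch (the model, data, and hyperparameters are placeholders; the closure-free step() call and the import path are assumptions about the unchanged parts of this module):

# Hypothetical optimization step with the EMA-tracking AdamW variant above.
import torch
import torch.nn.functional as F

from modelscope.models.cv.image_to_3d.ldm.util import \
    AdamWwithEMAandWings  # assumed to live in the same util module as above

model = torch.nn.Linear(8, 4)  # stand-in module
opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.9999)

x, y = torch.rand(16, 8), torch.rand(16, 4)
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()       # AdamW update plus EMA update of state['param_exp_avg']
opt.zero_grad()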

View File

@@ -1,28 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict
import rembg
import cv2
import numpy as np
import PIL
import rembg
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import save_image
from omegaconf import OmegaConf
# import modelscope.models.cv.image_to_image_generation.data as data
# import modelscope.models.cv.image_to_image_generation.models as models
# import modelscope.models.cv.image_to_image_generation.ops as ops
from modelscope.metainfo import Pipelines
# from modelscope.models.cv.image_to_3d.model import UNet
# from modelscope.models.cv.image_to_image_generation.models.clip import \
# VisionTransformer
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config, add_margin
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import \
SyncMultiviewDiffusion
from modelscope.models.cv.image_to_3d.ldm.util import (add_margin,
instantiate_from_config)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
@@ -31,23 +30,29 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
# from modelscope.models.cv.image_to_3d.model import UNet
# from modelscope.models.cv.image_to_image_generation.models.clip import \
# VisionTransformer
logger = get_logger()
# Load Syncdreamer Model
def load_model(cfg, ckpt, strict=True):
config = OmegaConf.load(cfg)
model = instantiate_from_config(config.model)
print(f'loading model from {ckpt} ...')
ckpt = torch.load(ckpt,map_location='cpu')
model.load_state_dict(ckpt['state_dict'],strict=strict)
ckpt = torch.load(ckpt, map_location='cpu')
model.load_state_dict(ckpt['state_dict'], strict=strict)
model = model.cuda().eval()
return model
# Prepare Syncdreamer Input
def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
image_input[:,:,:3] = image_input[:,:,:3][:,:,::-1]
image_input[:, :, :3] = image_input[:, :, :3][:, :, ::-1]
image_input = Image.fromarray(image_input)
if crop_size!=-1:
if crop_size != -1:
alpha_np = np.asarray(image_input)[:, :, 3]
coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
min_x, min_y = np.min(coords, 0)
@@ -59,21 +64,26 @@ def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
image_input = add_margin(ref_img_, size=image_size)
else:
image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
image_input = add_margin(
image_input, size=max(image_input.height, image_input.width))
image_input = image_input.resize((image_size, image_size),
resample=Image.BICUBIC)
image_input = np.asarray(image_input)
image_input = image_input.astype(np.float32) / 255.0
ref_mask = image_input[:, :, 3:]
image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
    # white background
    image_input[:, :, :3] = (
        image_input[:, :, :3] * ref_mask + 1 - ref_mask)
image_input = image_input[:, :, :3] * 2.0 - 1.0
image_input = torch.from_numpy(image_input.astype(np.float32))
elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
return {"input_image": image_input, "input_elevation": elevation_input}
elevation_input = torch.from_numpy(
np.asarray([np.deg2rad(elevation_input)], np.float32))
return {'input_image': image_input, 'input_elevation': elevation_input}
@PIPELINES.register_module(
Tasks.image_to_3d,
module_name=Pipelines.image_to_3d)
Tasks.image_to_3d, module_name=Pipelines.image_to_3d)
class Image23DPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
@@ -91,23 +101,28 @@ class Image23DPipeline(Pipeline):
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
ckpt = config_path.replace("configuration.json", "syncdreamer-pretrain.ckpt")
self.model = load_model(config_path.replace("configuration.json", "syncdreamer.yaml"), ckpt).to(self._device)
ckpt = config_path.replace('configuration.json',
'syncdreamer-pretrain.ckpt')
self.model = load_model(
config_path.replace('configuration.json', 'syncdreamer.yaml'),
ckpt).to(self._device)
# os.system("pip install -r {}".format(config_path.replace("configuration.json", "requirements.txt")))
# assert isinstance(self.model, SyncMultiviewDiffusion)
def preprocess(self, input: Input) -> Dict[str, Any]:
result = rembg.remove(Image.open(input))
print(type(result))
img = np.array(result)
img[:,:,:3] = img[:,:,:3][:,:,::-1]
img[:, :, :3] = img[:, :, :3][:, :, ::-1]
# img = cv2.imread(input)
data = prepare_inputs(img, elevation_input=10, crop_size=200, image_size=256)
for k,v in data.items():
data = prepare_inputs(
img, elevation_input=10, crop_size=200, image_size=256)
for k, v in data.items():
data[k] = v.unsqueeze(0).cuda()
data[k] = torch.repeat_interleave(data[k], 1, dim=0) # only one sample
data[k] = torch.repeat_interleave(
data[k], 1, dim=0) # only one sample
return data
@torch.no_grad()
@@ -115,11 +130,11 @@ class Image23DPipeline(Pipeline):
x_sample = self.model.sample(input, 2.0, 8)
B, N, _, H, W = x_sample.shape
x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
x_sample = x_sample.astype(np.uint8)
show_in_im2 = [Image.fromarray(x_sample[0,ni]) for ni in range(N)]
return {'MViews':show_in_im2}
show_in_im2 = [Image.fromarray(x_sample[0, ni]) for ni in range(N)]
return {'MViews': show_in_im2}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
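
Beyond the registered pipeline, the helpers in this file can also be driven directly; a minimal sketch (file names, elevation, and GPU availability are assumptions, and rembg plus the pretrained checkpoint must already be available):

# Hypothetical direct use of the module-level load_model / prepare_inputs
# defined above; all paths are placeholders.
import numpy as np
import rembg
import torch
from PIL import Image

model = load_model('syncdreamer.yaml', 'syncdreamer-pretrain.ckpt')
rgba = np.array(rembg.remove(Image.open('input.png')))  # HxWx4 RGBA array
rgba[:, :, :3] = rgba[:, :, :3][:, :, ::-1]  # mirror preprocess()'s swap
data = prepare_inputs(rgba, elevation_input=10, crop_size=200, image_size=256)
data = {k: v.unsqueeze(0).cuda() for k, v in data.items()}  # batch of one
with torch.no_grad():
    views = model.sample(data, 2.0, 8)  # [B, N, 3, H, W], roughly in [-1, 1]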

View File

@@ -3,6 +3,7 @@ import unittest
import numpy as np
from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
@@ -24,11 +25,11 @@ class ImageTo3DTest(unittest.TestCase):
def pipeline_inference(self, pipeline: Pipeline, input: str):
result = pipeline(input['input_path'])
np_content = []
for idx,img in enumerate(result['MViews']):
for idx, img in enumerate(result['MViews']):
np_content.append(np.array(result['MViews'][idx]))
np_content = np.concatenate(np_content, axis=1)
Image.fromarray(np_content).save("./concat.png")
Image.fromarray(np_content).save('./concat.png')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
@@ -38,4 +39,4 @@ class ImageTo3DTest(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()