Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 08:17:45 +01:00)

fix flake8
@@ -1 +0,0 @@
@@ -1 +0,0 @@
@@ -1,2 +1,2 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from . import ldm
from . import ldm

@@ -1,6 +1,7 @@
import pickle
import numpy as np

import cv2
import numpy as np
from skimage.io import imread


@@ -9,16 +10,18 @@ def save_pickle(data, pkl_path):
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)


def read_pickle(pkl_path):
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


def draw_epipolar_line(F, img0, img1, pt0, color):
    h1,w1=img1.shape[:2]
    h1, w1 = img1.shape[:2]
    hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None]
    l = F @ hpt
    l = l[:, 0]
    a, b, c = l[0], l[1], l[2]
    _l = F @ hpt
    _l = _l[:, 0]
    a, b, c = _l[0], _l[1], _l[2]
    pt1 = np.asarray([0, -c / b]).astype(np.int32)
    pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32)

@@ -26,8 +29,9 @@ def draw_epipolar_line(F, img0, img1, pt0, color):
    img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2)
    return img0, img1

def draw_epipolar_lines(F, img0, img1,num=20):
    img0,img1=img0.copy(),img1.copy()

def draw_epipolar_lines(F, img0, img1, num=20):
    img0, img1 = img0.copy(), img1.copy()
    h0, w0, _ = img0.shape
    h1, w1, _ = img1.shape

@@ -42,117 +46,166 @@ def draw_epipolar_lines(F, img0, img1,num=20):

    return img0, img1


def compute_F(K1, K2, Rt0, Rt1=None):
    if Rt1 is None:
        R, t = Rt0[:,:3], Rt0[:,3:]
        R, t = Rt0[:, :3], Rt0[:, 3:]
    else:
        Rt = compute_dR_dt(Rt0,Rt1)
        R, t = Rt[:,:3], Rt[:,3:]
    A = K1 @ R.T @ t # [3,1]
    C = np.asarray([[0,-A[2,0],A[1,0]],
                    [A[2,0],0,-A[0,0]],
                    [-A[1,0],A[0,0],0]])
        Rt = compute_dR_dt(Rt0, Rt1)
        R, t = Rt[:, :3], Rt[:, 3:]
    A = K1 @ R.T @ t  # [3,1]
    C = np.asarray([[0, -A[2, 0], A[1, 0]], [A[2, 0], 0, -A[0, 0]],
                    [-A[1, 0], A[0, 0], 0]])
    F = (np.linalg.inv(K2)).T @ R @ K1.T @ C
    return F


def compute_dR_dt(Rt0, Rt1):
    R0, t0 = Rt0[:,:3], Rt0[:,3:]
    R1, t1 = Rt1[:,:3], Rt1[:,3:]
    R0, t0 = Rt0[:, :3], Rt0[:, 3:]
    R1, t1 = Rt1[:, :3], Rt1[:, 3:]
    dR = np.dot(R1, R0.T)
    dt = t1 - np.dot(dR, t0)
    return np.concatenate([dR, dt], -1)

def concat_images(img0,img1,vert=False):

def concat_images(img0, img1, vert=False):
    if not vert:
        h0,h1=img0.shape[0],img1.shape[0],
        if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
        if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
        h0, h1 = img0.shape[0], img1.shape[0],
        if h0 < h1:
            img0 = cv2.copyMakeBorder(
                img0,
                0,
                h1 - h0,
                0,
                0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        if h1 < h0:
            img1 = cv2.copyMakeBorder(
                img1,
                0,
                h0 - h1,
                0,
                0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        img = np.concatenate([img0, img1], axis=1)
    else:
        w0,w1=img0.shape[1],img1.shape[1]
        if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
        if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
        w0, w1 = img0.shape[1], img1.shape[1]
        if w0 < w1:
            img0 = cv2.copyMakeBorder(
                img0,
                0,
                0,
                0,
                w1 - w0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        if w1 < w0:
            img1 = cv2.copyMakeBorder(
                img1,
                0,
                0,
                0,
                w0 - w1,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        img = np.concatenate([img0, img1], axis=0)

    return img

def concat_images_list(*args,vert=False):
    if len(args)==1: return args[0]
    img_out=args[0]

def concat_images_list(*args, vert=False):
    if len(args) == 1:
        return args[0]
    img_out = args[0]
    for img in args[1:]:
        img_out=concat_images(img_out,img,vert)
        img_out = concat_images(img_out, img, vert)
    return img_out


def pose_inverse(pose):
    R = pose[:,:3].T
    t = - R @ pose[:,3:]
    return np.concatenate([R,t],-1)
    R = pose[:, :3].T
    t = -R @ pose[:, 3:]
    return np.concatenate([R, t], -1)

def project_points(pts,RT,K):
    pts = np.matmul(pts,RT[:,:3].transpose())+RT[:,3:].transpose()
    pts = np.matmul(pts,K.transpose())
    dpt = pts[:,2]
    mask0 = (np.abs(dpt)<1e-4) & (np.abs(dpt)>0)
    if np.sum(mask0)>0: dpt[mask0]=1e-4
    mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
    if np.sum(mask1)>0: dpt[mask1]=-1e-4
    pts2d = pts[:,:2]/dpt[:,None]

def project_points(pts, RT, K):
    pts = np.matmul(pts, RT[:, :3].transpose()) + RT[:, 3:].transpose()
    pts = np.matmul(pts, K.transpose())
    dpt = pts[:, 2]
    mask0 = (np.abs(dpt) < 1e-4) & (np.abs(dpt) > 0)
    if np.sum(mask0) > 0:
        dpt[mask0] = 1e-4
    mask1 = (np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
    if np.sum(mask1) > 0:
        dpt[mask1] = -1e-4
    pts2d = pts[:, :2] / dpt[:, None]
    return pts2d, dpt


def draw_keypoints(img, kps, colors=None, radius=2):
    out_img=img.copy()
    out_img = img.copy()
    for pi, pt in enumerate(kps):
        pt = np.round(pt).astype(np.int32)
        if colors is not None:
            color=[int(c) for c in colors[pi]]
            color = [int(c) for c in colors[pi]]
            cv2.circle(out_img, tuple(pt), radius, color, -1)
        else:
            cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1)
            cv2.circle(out_img, tuple(pt), radius, (0, 255, 0), -1)
    return out_img


def output_points(fn,pts,colors=None):
def output_points(fn, pts, colors=None):
    with open(fn, 'w') as f:
        for pi, pt in enumerate(pts):
            f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
            if colors is not None:
                f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
                f.write(
                    f'{int(colors[pi, 0])} {int(colors[pi, 1])} {int(colors[pi, 2])}'
                )
            f.write('\n')


DEPTH_MAX, DEPTH_MIN = 2.4, 0.6
DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63


def read_depth_objaverse(depth_fn):
    depth = imread(depth_fn)
    depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN
    depth = depth.astype(
        np.float32) / 65535 * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN
    mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX)
    return depth, mask


def mask_depth_to_pts(mask,depth,K,rgb=None):
    hs,ws=np.nonzero(mask)
    depth=depth[hs,ws]
    pts=np.asarray([ws,hs,depth],np.float32).transpose()
    pts[:,:2]*=pts[:,2:]
def mask_depth_to_pts(mask, depth, K, rgb=None):
    hs, ws = np.nonzero(mask)
    depth = depth[hs, ws]
    pts = np.asarray([ws, hs, depth], np.float32).transpose()
    pts[:, :2] *= pts[:, 2:]
    if rgb is not None:
        return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws]
        return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs, ws]
    else:
        return np.dot(pts, np.linalg.inv(K).transpose())


def transform_points_pose(pts, pose):
    R, t = pose[:, :3], pose[:, 3]
    if len(pts.shape)==1:
        return (R @ pts[:,None] + t[:,None])[:,0]
    return pts @ R.T + t[None,:]
    if len(pts.shape) == 1:
        return (R @ pts[:, None] + t[:, None])[:, 0]
    return pts @ R.T + t[None, :]

def pose_apply(pose,pts):

def pose_apply(pose, pts):
    return transform_points_pose(pts, pose)


def downsample_gaussian_blur(img, ratio):
    sigma = (1 / ratio) / 3
    # ksize=np.ceil(2*sigma)
    ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1))
    ksize = ksize + 1 if ksize % 2 == 0 else ksize
    img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
    return img
    img = cv2.GaussianBlur(
        img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
    return img

@@ -1,34 +1,36 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import (
|
||||
Decoder, Encoder)
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import \
|
||||
DiagonalGaussianDistribution
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
@@ -36,24 +38,31 @@ class VQModel(pl.LightningModule):
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.quantize = VectorQuantizer(
|
||||
n_embed,
|
||||
embed_dim,
|
||||
beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
print(
|
||||
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
|
||||
)
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
@@ -66,28 +75,30 @@ class VQModel(pl.LightningModule):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
print(f'{context}: Switched to EMA weights')
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
print(f'{context}: Restored training weights')
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
print(
|
||||
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
print(f'Missing Keys: {missing}')
|
||||
print(f'Unexpected Keys: {unexpected}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
@@ -115,7 +126,7 @@ class VQModel(pl.LightningModule):
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
quant, diff, (_, _, ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
@@ -125,7 +136,8 @@ class VQModel(pl.LightningModule):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
@@ -133,9 +145,10 @@ class VQModel(pl.LightningModule):
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
new_resize = np.random.choice(
|
||||
np.arange(lower_size, upper_size + 16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = F.interpolate(x, size=new_resize, mode='bicubic')
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
@@ -147,79 +160,122 @@ class VQModel(pl.LightningModule):
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train')
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
self._validation_step(batch, batch_idx, suffix='_ema')
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
def _validation_step(self, batch, batch_idx, suffix=''):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind)
|
||||
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log(
|
||||
f'val{suffix}/rec_loss',
|
||||
rec_loss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True)
|
||||
self.log(
|
||||
f'val{suffix}/aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
del log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
lr_g = self.lr_g_factor * self.learning_rate
|
||||
print('lr_d', lr_d)
|
||||
print('lr_g', lr_g)
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
+ list(self.quantize.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g,
|
||||
betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
print('Setting up LambdaLR scheduler...')
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'scheduler':
|
||||
LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'scheduler':
|
||||
LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
@@ -235,7 +291,7 @@ class VQModel(pl.LightningModule):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
log['inputs'] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
@@ -243,25 +299,28 @@ class VQModel(pl.LightningModule):
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
log['inputs'] = x
|
||||
log['reconstructions'] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
if x.shape[1] > 3:
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
@@ -283,43 +342,48 @@ class VQModelInterface(VQModel):
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
print(f'Restored from {path}')
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
@@ -345,7 +409,8 @@ class AutoencoderKL(pl.LightningModule):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
@@ -354,44 +419,91 @@ class AutoencoderKL(pl.LightningModule):
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train')
|
||||
self.log(
|
||||
'aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train')
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
self.log(
|
||||
'discloss',
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val')
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val')
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
@@ -409,21 +521,23 @@ class AutoencoderKL(pl.LightningModule):
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,17 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import default, zero_module, checkpoint
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
import modelscope.models.cv.image_to_3d.ldm.modules.attention as attention
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import \
|
||||
UNetModel
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
|
||||
timestep_embedding
|
||||
|
||||
|
||||
class DepthAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim,
|
||||
heads,
|
||||
dim_head,
|
||||
output_bias=True):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
context_dim = attention.default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
@@ -34,21 +43,27 @@ class DepthAttention(nn.Module):
|
||||
b, _, h, w = x.shape
|
||||
b, _, d, h, w = context.shape
|
||||
|
||||
q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w
|
||||
k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
|
||||
v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
|
||||
q = self.to_q(x).reshape(b, hn, hd, h, w) # b,t,h,w
|
||||
k = self.to_k(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w
|
||||
v = self.to_v(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w
|
||||
|
||||
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
|
||||
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
|
||||
attn = sim.softmax(dim=2)
|
||||
|
||||
# b,hn,hd,d,h,w * b,hn,1,d,h,w
|
||||
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
|
||||
out = out.reshape(b,hn*hd,h,w)
|
||||
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
|
||||
out = out.reshape(b, hn * hd, h, w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class DepthTransformer(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
checkpoint=True):
|
||||
super().__init__()
|
||||
inner_dim = n_heads * d_head
|
||||
self.proj_in = nn.Sequential(
|
||||
@@ -57,23 +72,33 @@ class DepthTransformer(nn.Module):
|
||||
nn.SiLU(True),
|
||||
)
|
||||
self.proj_context = nn.Sequential(
|
||||
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
|
||||
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
|
||||
nn.GroupNorm(8, context_dim),
|
||||
nn.ReLU(True), # only relu, because we want input is 0, output is 0
|
||||
nn.ReLU(
|
||||
True), # only relu, because we want input is 0, output is 0
|
||||
)
|
||||
self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # is a self-attention if not self.disable_self_attn
|
||||
self.depth_attn = DepthAttention(
|
||||
query_dim=inner_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=context_dim,
|
||||
output_bias=False
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.proj_out = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
|
||||
attention.zero_module(
|
||||
nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
|
||||
)
|
||||
self.checkpoint = checkpoint
|
||||
self.checkpoint = attention.checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
return attention.checkpoint(self._forward, (x, context),
|
||||
self.parameters(), self.checkpoint) # noqa
|
||||
|
||||
def _forward(self, x, context):
|
||||
x_in = x
|
||||
@@ -85,38 +110,65 @@ class DepthTransformer(nn.Module):
|
||||
|
||||
|
||||
class DepthWiseAttention(UNetModel):
|
||||
def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs):
|
||||
|
||||
def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# num_heads = 4
|
||||
model_channels = kwargs['model_channels']
|
||||
channel_mult = kwargs['channel_mult']
|
||||
d0,d1,d2,d3 = volume_dims
|
||||
d0, d1, d2, d3 = volume_dims
|
||||
|
||||
# 4
|
||||
ch = model_channels*channel_mult[2]
|
||||
self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)
|
||||
ch = model_channels * channel_mult[2]
|
||||
self.middle_conditions = DepthTransformer(
|
||||
ch, 4, d3 // 2, context_dim=d3)
|
||||
|
||||
self.output_conditions=nn.ModuleList()
|
||||
self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8}
|
||||
self.output_conditions = nn.ModuleList()
|
||||
self.output_b2c = {
|
||||
3: 0,
|
||||
4: 1,
|
||||
5: 2,
|
||||
6: 3,
|
||||
7: 4,
|
||||
8: 5,
|
||||
9: 6,
|
||||
10: 7,
|
||||
11: 8
|
||||
}
|
||||
# 8
|
||||
ch = model_channels*channel_mult[2]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
|
||||
ch = model_channels * channel_mult[2]
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
|
||||
# 16
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
|
||||
ch = model_channels*channel_mult[1]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
|
||||
ch = model_channels * channel_mult[1]
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
|
||||
# 32
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
|
||||
ch = model_channels*channel_mult[0]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
|
||||
ch = model_channels * channel_mult[0]
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
|
||||
self.output_conditions.append(
|
||||
DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
|
||||
def forward(self,
|
||||
x,
|
||||
timesteps=None,
|
||||
context=None,
|
||||
source_dict=None,
|
||||
**kwargs):
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
@@ -138,5 +190,6 @@ class DepthWiseAttention(UNetModel):
|
||||
return self.out(h)
|
||||
|
||||
def get_trainable_parameters(self):
|
||||
paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()]
|
||||
paras = [para for para in self.middle_conditions.parameters()
|
||||
] + [para for para in self.output_conditions.parameters()]
|
||||
return paras
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Image2DResBlockWithTV(nn.Module):
|
||||
|
||||
def __init__(self, dim, tdim, vdim):
|
||||
super().__init__()
|
||||
norm = lambda c: nn.GroupNorm(8, c)
|
||||
norm = lambda c: nn.GroupNorm(8, c) # noqa
|
||||
self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
|
||||
self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
|
||||
self.conv = nn.Sequential(
|
||||
@@ -17,22 +19,28 @@ class Image2DResBlockWithTV(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
return x+self.conv(x+self.time_embed(t)+self.view_embed(v))
|
||||
return x + self.conv(x + self.time_embed(t) + self.view_embed(v))
|
||||
|
||||
|
||||
class NoisyTargetViewEncoder(nn.Module):
|
||||
def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
|
||||
|
||||
def __init__(self,
|
||||
time_embed_dim,
|
||||
viewpoint_dim,
|
||||
run_dim=16,
|
||||
output_dim=8):
|
||||
super().__init__()
|
||||
|
||||
self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
|
||||
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim,
|
||||
viewpoint_dim)
|
||||
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim,
|
||||
viewpoint_dim)
|
||||
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim,
|
||||
viewpoint_dim)
|
||||
self.final_out = nn.Sequential(
|
||||
nn.GroupNorm(8, run_dim),
|
||||
nn.SiLU(True),
|
||||
nn.Conv2d(run_dim, output_dim, 3, 1, 1)
|
||||
)
|
||||
nn.GroupNorm(8, run_dim), nn.SiLU(True),
|
||||
nn.Conv2d(run_dim, output_dim, 3, 1, 1))
|
||||
|
||||
def forward(self, x, t, v):
|
||||
B, DT = t.shape
|
||||
@@ -47,23 +55,33 @@ class NoisyTargetViewEncoder(nn.Module):
|
||||
x = self.final_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialUpTimeBlock(nn.Module):
|
||||
|
||||
def __init__(self, x_in_dim, t_in_dim, out_dim):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
|
||||
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
|
||||
self.norm = norm_act(x_in_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
|
||||
self.conv = nn.ConvTranspose3d(
|
||||
x_in_dim,
|
||||
out_dim,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
stride=2)
|
||||
|
||||
def forward(self, x, t):
|
||||
x = x + self.t_conv(t)
|
||||
return self.conv(self.silu(self.norm(x)))
|
||||
|
||||
|
||||
class SpatialTimeBlock(nn.Module):
|
||||
|
||||
def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
|
||||
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
|
||||
self.bn = norm_act(x_in_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
@@ -73,61 +91,65 @@ class SpatialTimeBlock(nn.Module):
|
||||
x = x + self.t_conv(t)
|
||||
return self.conv(self.silu(self.bn(x)))
|
||||
|
||||
|
||||
class SpatialTime3DNet(nn.Module):
|
||||
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
|
||||
super().__init__()
|
||||
d0, d1, d2, d3 = dims
|
||||
dt = time_dim
|
||||
|
||||
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
|
||||
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
|
||||
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
|
||||
super().__init__()
|
||||
d0, d1, d2, d3 = dims
|
||||
dt = time_dim
|
||||
|
||||
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
|
||||
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
|
||||
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
|
||||
|
||||
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
|
||||
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
|
||||
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
|
||||
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
|
||||
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
|
||||
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
|
||||
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
|
||||
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
|
||||
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
|
||||
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
|
||||
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
|
||||
def forward(self, x, t):
|
||||
B, C = t.shape
|
||||
t = t.view(B, C, 1, 1, 1)
|
||||
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
|
||||
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
|
||||
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
|
||||
|
||||
x = self.init_conv(x)
|
||||
conv0 = self.conv0(x, t)
|
||||
def forward(self, x, t):
|
||||
B, C = t.shape
|
||||
t = t.view(B, C, 1, 1, 1)
|
||||
|
||||
x = self.conv1(conv0, t)
|
||||
x = self.conv2_0(x, t)
|
||||
conv2 = self.conv2_1(x, t)
|
||||
x = self.init_conv(x)
|
||||
conv0 = self.conv0(x, t)
|
||||
|
||||
x = self.conv3(conv2, t)
|
||||
x = self.conv4_0(x, t)
|
||||
conv4 = self.conv4_1(x, t)
|
||||
x = self.conv1(conv0, t)
|
||||
x = self.conv2_0(x, t)
|
||||
conv2 = self.conv2_1(x, t)
|
||||
|
||||
x = self.conv5(conv4, t)
|
||||
x = self.conv6_0(x, t)
|
||||
x = self.conv6_1(x, t)
|
||||
x = self.conv3(conv2, t)
|
||||
x = self.conv4_0(x, t)
|
||||
conv4 = self.conv4_1(x, t)
|
||||
|
||||
x = self.conv5(conv4, t)
|
||||
x = self.conv6_0(x, t)
|
||||
x = self.conv6_1(x, t)
|
||||
|
||||
x = conv4 + self.conv7(x, t)
|
||||
x = conv2 + self.conv8(x, t)
|
||||
x = conv0 + self.conv9(x, t)
|
||||
return x
|
||||
|
||||
x = conv4 + self.conv7(x, t)
|
||||
x = conv2 + self.conv8(x, t)
|
||||
x = conv0 + self.conv9(x, t)
|
||||
return x
|
||||
|
||||
class FrustumTVBlock(nn.Module):
|
||||
|
||||
def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
self.bn = norm_act(x_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)
|
||||
@@ -136,24 +158,34 @@ class FrustumTVBlock(nn.Module):
|
||||
x = x + self.t_conv(t) + self.v_conv(v)
|
||||
return self.conv(self.silu(self.bn(x)))
|
||||
|
||||
|
||||
class FrustumTVUpBlock(nn.Module):
|
||||
|
||||
def __init__(self, x_dim, t_dim, v_dim, out_dim):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
norm_act = lambda c: nn.GroupNorm(8, c) # noqa
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
self.norm = norm_act(x_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
|
||||
self.conv = nn.ConvTranspose3d(
|
||||
x_dim,
|
||||
out_dim,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
stride=2)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
x = x + self.t_conv(t) + self.v_conv(v)
|
||||
return self.conv(self.silu(self.norm(x)))
|
||||
|
||||
|
||||
class FrustumTV3DNet(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
|
||||
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
|
||||
|
||||
self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
|
||||
self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)
|
||||
@@ -169,10 +201,10 @@ class FrustumTV3DNet(nn.Module):
|
||||
self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])
|
||||
|
||||
def forward(self, x, t, v):
|
||||
B,DT = t.shape
|
||||
t = t.view(B,DT,1,1,1)
|
||||
B,DV = v.shape
|
||||
v = v.view(B,DV,1,1,1)
|
||||
B, DT = t.shape
|
||||
t = t.view(B, DT, 1, 1, 1)
|
||||
B, DV = v.shape
|
||||
v = v.view(B, DV, 1, 1, 1)
|
||||
|
||||
b, _, d, h, w = x.shape
|
||||
x0 = self.conv0(x)
|
||||
@@ -183,4 +215,4 @@ class FrustumTV3DNet(nn.Module):
|
||||
x2 = self.up0(x3, t, v) + x2
|
||||
x1 = self.up1(x2, t, v) + x1
|
||||
x0 = self.up2(x1, t, v) + x0
|
||||
return {w: x0, w//2: x1, w//4: x2, w//8: x3}
|
||||
return {w: x0, w // 2: x1, w // 4: x2, w // 8: x3}
|
||||
|
||||
@@ -10,13 +10,13 @@ def project_and_normalize(ref_grid, src_proj, length):
|
||||
@param length: int
|
||||
@return: b, n, 2
|
||||
"""
|
||||
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
|
||||
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
|
||||
div_val = src_grid[:, -1:]
|
||||
div_val[div_val<1e-4] = 1e-4
|
||||
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
|
||||
src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
|
||||
div_val[div_val < 1e-4] = 1e-4
|
||||
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
|
||||
src_grid[:, 0] = src_grid[:, 0] / ((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid[:, 1] = src_grid[:, 1] / ((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
|
||||
return src_grid
|
||||
|
||||
|
||||
@@ -29,38 +29,55 @@ def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
|
||||
@return:
|
||||
"""
|
||||
rfn = Ks.shape[0]
|
||||
scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device)
|
||||
scale_m = torch.tensor([x_ratio, y_ratio, 1.0],
|
||||
dtype=torch.float32,
|
||||
device=Ks.device)
|
||||
scale_m = torch.diag(scale_m)
|
||||
ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4
|
||||
pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device)
|
||||
pad_vals = torch.zeros([rfn, 1, 4],
|
||||
dtype=torch.float32,
|
||||
device=ref_prj.device)
|
||||
pad_vals[:, :, 3] = 1.0
|
||||
ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4
|
||||
return ref_prj
|
||||
|
||||
|
||||
def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
|
||||
B, _, D, H, W = volume_xyz.shape
|
||||
ratio = warp_size / input_size
|
||||
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
|
||||
warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2)
|
||||
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
|
||||
warp_coords = project_and_normalize(
|
||||
volume_xyz.view(B, 3, D * H * W), warp_proj,
|
||||
warp_size).view(B, D, H, W, 2)
|
||||
return warp_coords
|
||||
|
||||
|
||||
def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
|
||||
def create_target_volume(depth_size,
|
||||
volume_size,
|
||||
input_image_size,
|
||||
pose_target,
|
||||
K,
|
||||
near=None,
|
||||
far=None):
|
||||
device, dtype = pose_target.device, pose_target.dtype
|
||||
|
||||
# compute a depth range on the unit sphere
|
||||
H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
|
||||
if near is not None and far is not None :
|
||||
if near is not None and far is not None:
|
||||
# near, far b,1,h,w
|
||||
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
|
||||
depth_values = depth_values * (far - near) + near # b d h w
|
||||
depth_values = torch.linspace(
|
||||
0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
|
||||
depth_values = depth_values * (far - near) + near # b d h w
|
||||
depth_values = depth_values.view(B, 1, D, H * W)
|
||||
else:
|
||||
near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1
|
||||
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1
|
||||
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W)
|
||||
near, far = near_far_from_unit_sphere_using_camera_poses(
|
||||
pose_target) # b 1
|
||||
depth_values = torch.linspace(
|
||||
0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values[None, :, None] * (
|
||||
far[:, None, :] - near[:, None, :]) + near[:, None, :] # b d 1
|
||||
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H * W)
|
||||
|
||||
ratio = volume_size / input_image_size
|
||||
|
||||
@@ -68,20 +85,28 @@ def create_target_volume(depth_size, volume_size, input_image_size, pose_target,
|
||||
# H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0]
|
||||
|
||||
# creat mesh grid: note reference also means target
|
||||
ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2)
|
||||
ref_grid = create_meshgrid(
|
||||
H, W, normalized_coordinates=False) # (1, H, W, 2)
|
||||
ref_grid = ref_grid.to(device).to(dtype)
|
||||
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
|
||||
ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W)
|
||||
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
|
||||
ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, H*W)
|
||||
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
|
||||
ref_grid = ref_grid.reshape(1, 2, H * W) # (1, 2, H*W)
|
||||
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
|
||||
ref_grid = torch.cat(
|
||||
(ref_grid,
|
||||
torch.ones(B, 1, H * W, dtype=ref_grid.dtype,
|
||||
device=ref_grid.device)),
|
||||
dim=1) # (B, 3, H*W)
|
||||
ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W)
|
||||
|
||||
# unproject to space and transfer to world coordinates.
|
||||
Ks = K
|
||||
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
|
||||
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
|
||||
ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
|
||||
return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W)
|
||||
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
|
||||
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
|
||||
ref_grid = ref_proj_inv[:, :3, :3] @ ref_grid.view(
|
||||
B, 3, D * H
|
||||
* W) + ref_proj_inv[:, :3, 3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
|
||||
return ref_grid.reshape(B, 3, D, H, W), depth_values.view(B, 1, D, H, W)
|
||||
|
||||
|
||||
def near_far_from_unit_sphere_using_camera_poses(camera_poses):
|
||||
"""
|
||||
@@ -90,14 +115,16 @@ def near_far_from_unit_sphere_using_camera_poses(camera_poses):
|
||||
near: b,1
|
||||
far: b,1
|
||||
"""
|
||||
R_w2c = camera_poses[..., :3, :3] # b 3 3
|
||||
t_w2c = camera_poses[..., :3, 3:] # b 3 1
|
||||
camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1
|
||||
R_w2c = camera_poses[..., :3, :3] # b 3 3
|
||||
t_w2c = camera_poses[..., :3, 3:] # b 3 1
|
||||
camera_origin = -R_w2c.permute(0, 2, 1) @ t_w2c # b 3 1
|
||||
# R_w2c.T @ (0,0,1) = z_dir
|
||||
camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1
|
||||
camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3
|
||||
a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1
|
||||
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
|
||||
mid = b / a # b 1
|
||||
camera_orient = R_w2c.permute(0, 2, 1)[..., :3, 2:3] # b 3 1
|
||||
camera_origin, camera_orient = camera_origin[...,
|
||||
0], camera_orient[...,
|
||||
0] # b 3
|
||||
a = torch.sum(camera_orient**2, dim=-1, keepdim=True) # b 1
|
||||
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
|
||||
mid = b / a # b 1
|
||||
near, far = mid - 1.0, mid + 1.0
|
||||
return near, far
|
||||
return near, far
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import checkpoint
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
|
||||
checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
@@ -13,7 +15,7 @@ def exists(val):
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
@@ -35,6 +37,7 @@ def init_(tensor):

# feedforward
class GEGLU(nn.Module):

def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -42,8 +45,11 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)


# feedforward
class ConvGEGLU(nn.Module):

def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0)
@@ -54,20 +60,16 @@ class ConvGEGLU(nn.Module):


class FeedForward(nn.Module):

def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
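For context, a minimal sketch of what the GEGLU block above computes (independent of the modelscope classes; sizes are arbitrary): the linear projection doubles the width, half of it acts as the value and the other half gates it through GELU.

import torch
import torch.nn.functional as F
from torch import nn

proj = nn.Linear(16, 32 * 2)            # dim_in=16, dim_out=32, width doubled for the gate
x = torch.randn(4, 10, 16)              # (batch, tokens, dim_in)
value, gate = proj(x).chunk(2, dim=-1)  # split into value and gate halves
out = value * F.gelu(gate)              # same computation as GEGLU.forward
print(out.shape)                        # torch.Size([4, 10, 32])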
@@ -83,54 +85,54 @@ def zero_module(module):
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
q, k, v = rearrange(
|
||||
qkv,
|
||||
'b (qkv heads c) h w -> qkv b heads c (h w)',
|
||||
heads=self.heads,
|
||||
qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
out = rearrange(
|
||||
out,
|
||||
'b heads c (h w) -> b (heads c) h w',
|
||||
heads=self.heads,
|
||||
h=h,
|
||||
w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@@ -140,7 +142,7 @@ class SpatialSelfAttention(nn.Module):
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
@@ -155,16 +157,22 @@ class SpatialSelfAttention(nn.Module):
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
return x + h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
@@ -172,9 +180,7 @@ class CrossAttention(nn.Module):
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
@@ -184,12 +190,13 @@ class CrossAttention(nn.Module):
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask>0
|
||||
mask = mask > 0
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
@@ -202,8 +209,15 @@ class CrossAttention(nn.Module):
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
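A shape-level sketch of the attention path in CrossAttention.forward (standalone, with made-up sizes; the softmax and value einsum are the standard formulation and are assumed here rather than copied from the part of the hunk that is not shown): heads are folded into the batch dimension before the similarity einsum and unfolded again at the end.

import torch
from einops import rearrange
from torch import einsum

b, n, h, d = 2, 10, 8, 64
q = torch.randn(b, n, h * d)
k = torch.randn(b, n, h * d)
v = torch.randn(b, n, h * d)

# fold heads into the batch dimension
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * d**-0.5   # (b*h, n, n) similarity logits
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)     # unfold heads
print(out.shape)                                        # torch.Size([2, 10, 512])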
|
||||
|
||||
class BasicSpatialTransformer(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
checkpoint=True):
|
||||
super().__init__()
|
||||
inner_dim = n_heads * d_head
|
||||
self.proj_in = nn.Sequential(
|
||||
@@ -212,7 +226,12 @@ class BasicSpatialTransformer(nn.Module):
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn
|
||||
self.attn = CrossAttention(
|
||||
query_dim=inner_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=context_dim
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.out_conv = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
@@ -221,16 +240,18 @@ class BasicSpatialTransformer(nn.Module):
|
||||
self.proj_out = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
|
||||
zero_module(
|
||||
nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
|
||||
)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(),
|
||||
self.checkpoint)
|
||||
|
||||
def _forward(self, x, context):
|
||||
# input
|
||||
b,_,h,w = x.shape
|
||||
b, _, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.proj_in(x)
|
||||
|
||||
@@ -245,44 +266,64 @@ class BasicSpatialTransformer(nn.Module):
|
||||
x = self.proj_out(x) + x_in
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else
|
||||
None) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(),
|
||||
self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn1(
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class ConvFeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU()
|
||||
) if not glu else ConvGEGLU(dim, inner_dim)
|
||||
nn.GELU()) if not glu else ConvGEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim_out, 1, 1, 0)
|
||||
)
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim_out, 1, 1, 0))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
@@ -296,31 +337,36 @@ class SpatialTransformer(nn.Module):
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None,
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn)
|
||||
for d in range(depth)]
|
||||
)
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn) for d in range(depth)
|
||||
])
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(
|
||||
inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
@@ -8,16 +8,11 @@ import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import \
|
||||
SpatialTransformer
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer
|
||||
avg_pool_nd, checkpoint, conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import exists
|
||||
|
||||
|
||||
@@ -25,11 +20,11 @@ from modelscope.models.cv.image_to_3d.ldm.util import exists
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
|
||||
## go
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
@@ -43,7 +38,8 @@ class AttentionPool2d(nn.Module):
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
@@ -98,37 +94,46 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
self.conv = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||
mode='nearest')
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
self.up = nn.ConvTranspose2d(
|
||||
self.channels, self.out_channels, kernel_size=ks, stride=2)
|
||||
|
||||
def forward(self,x):
|
||||
def forward(self, x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
@@ -141,7 +146,12 @@ class Downsample(nn.Module):
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -150,8 +160,12 @@ class Downsample(nn.Module):
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
)
|
||||
dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
@@ -220,7 +234,8 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
2 * self.out_channels
|
||||
if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
@@ -228,18 +243,18 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
conv_nd(
|
||||
dims, self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
dims, channels, self.out_channels, 3, padding=1)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels,
|
||||
1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
@@ -248,10 +263,8 @@ class ResBlock(TimestepBlock):
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
return checkpoint(self._forward, (x, emb), self.parameters(),
|
||||
self.use_checkpoint)
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
@@ -265,7 +278,7 @@ class ResBlock(TimestepBlock):
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm: # False
|
||||
if self.use_scale_shift_norm: # False
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
@@ -298,7 +311,7 @@ class AttentionBlock(nn.Module):
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
@@ -313,8 +326,10 @@ class AttentionBlock(nn.Module):
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
return checkpoint(
|
||||
self._forward, (x, ), self.parameters(), True
|
||||
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
# return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
@@ -341,7 +356,7 @@ def count_flops_attn(model, _x, y):
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
@@ -363,13 +378,14 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
'bct,bcs->bts', q * scale,
k * scale)  # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)

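A short aside on the scale = 1 / math.sqrt(math.sqrt(ch)) line above (toy tensors, not project code): scaling q and k by ch**-0.25 each is algebraically the same as dividing the logits by sqrt(ch), but it keeps the intermediate products smaller, which is the float16 stability the comment refers to.

import math
import torch

ch, length = 64, 10
q = torch.randn(2, ch, length)
k = torch.randn(2, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum('bct,bcs->bts', q * scale, k * scale)   # scale applied to both operands
w2 = torch.einsum('bct,bcs->bts', q, k) / math.sqrt(ch)   # usual post-hoc division
print(torch.allclose(w1, w2, atol=1e-5))                  # True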
@staticmethod
|
||||
@@ -398,12 +414,13 @@ class QKVAttention(nn.Module):
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
'bct,bcs->bts',
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
a = th.einsum('bts,bcs->bct', weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
@@ -442,40 +459,43 @@ class UNetModel(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None
|
||||
):
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
assert context_dim is not None, (
|
||||
'Fool!! You forgot to include the dimension '
|
||||
'of your cross-attention conditioning...')
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
assert use_spatial_transformer, (
|
||||
'Fool!! You forgot to use the spatial transformer '
|
||||
'for your cross-attention conditioning...')
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
@@ -497,20 +517,28 @@ class UNetModel(nn.Module):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
raise ValueError(
|
||||
'provide num_res_blocks either as an int (globally constant) or '
|
||||
'as a list/tuple (per-level) with the same length as channel_mult'
|
||||
)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
#self.num_res_blocks = num_res_blocks
|
||||
# self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.") # todo: convert to warning
|
||||
assert all(
|
||||
map(
|
||||
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
|
||||
],
|
||||
range(len(num_attention_blocks))))
|
||||
print(
|
||||
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
|
||||
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
|
||||
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
|
||||
f'attention will still not be set.'
|
||||
) # todo: convert to warning
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
@@ -534,13 +562,10 @@ class UNetModel(nn.Module):
|
||||
if self.num_classes is not None:
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
) # 0
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
||||
]) # 0
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
@@ -559,21 +584,22 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions: # always True
|
||||
if ds in attention_resolutions: # always True
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
if not exists(num_attention_blocks
|
||||
) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -581,11 +607,14 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa
|
||||
)
|
||||
)
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
@@ -602,12 +631,8 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
) if resblock_updown else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
@@ -620,7 +645,7 @@ class UNetModel(nn.Module):
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
@@ -637,9 +662,13 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
),
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@@ -674,14 +703,15 @@ class UNetModel(nn.Module):
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
if not exists(num_attention_blocks
|
||||
) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@@ -689,11 +719,14 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa
|
||||
)
|
||||
)
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa))
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
@@ -706,10 +739,8 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
) if resblock_updown else Upsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
@@ -717,14 +748,15 @@ class UNetModel(nn.Module):
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(
|
||||
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
@@ -742,7 +774,7 @@ class UNetModel(nn.Module):
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@@ -753,18 +785,19 @@ class UNetModel(nn.Module):
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
), 'must specify y if and only if the model is class-conditional'
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N
|
||||
emb = self.time_embed(t_emb) #
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False) # N
|
||||
emb = self.time_embed(t_emb) #
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
assert y.shape == (x.shape[0], )
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context) # conv
|
||||
h = module(h, emb, context) # conv
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
@@ -783,30 +816,28 @@ class EncoderUNetModel(nn.Module):
|
||||
For usage, see UNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool='adaptive',
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
@@ -833,13 +864,10 @@ class EncoderUNetModel(nn.Module):
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
||||
])
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
@@ -866,8 +894,7 @@ class EncoderUNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
@@ -884,12 +911,8 @@ class EncoderUNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
) if resblock_updown else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
@@ -923,7 +946,7 @@ class EncoderUNetModel(nn.Module):
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
if pool == 'adaptive':
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
@@ -931,22 +954,21 @@ class EncoderUNetModel(nn.Module):
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
elif pool == 'attention':
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
AttentionPool2d(
|
||||
(image_size // ds), ch, num_head_channels, out_channels
|
||||
),
|
||||
AttentionPool2d((image_size // ds), ch, num_head_channels,
|
||||
out_channels),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
elif pool == 'spatial':
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
elif pool == 'spatial_v2':
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
@@ -954,7 +976,7 @@ class EncoderUNetModel(nn.Module):
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
raise NotImplementedError(f'Unexpected {pool} pooling')
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
@@ -977,20 +999,20 @@ class EncoderUNetModel(nn.Module):
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
if self.pool.startswith('spatial'):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
if self.pool.startswith('spatial'):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
@@ -7,50 +7,65 @@
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == 'linear':
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)

elif schedule == "cosine":
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)

elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == 'sqrt':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
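For reference, a standalone sketch of the 'linear' branch above with the default arguments (illustrative values, not the project's training configuration): betas are linear in sqrt-space, and the cumulative product of the alphas is what samplers consume.

import torch

n_timestep, linear_start, linear_end = 1000, 1e-4, 2e-2
betas = torch.linspace(
    linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2
alphas_cumprod = torch.cumprod(1. - betas, dim=0)
print(betas[0].item(), betas[-1].item())   # ~1e-4 and ~2e-2
print(alphas_cumprod[-1].item())           # close to zero, i.e. nearly pure noise at t = T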
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
def make_ddim_timesteps(ddim_discr_method,
|
||||
num_ddim_timesteps,
|
||||
num_ddpm_timesteps,
|
||||
verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
|
||||
num_ddim_timesteps))**2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
raise NotImplementedError(
|
||||
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
||||
)
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
@@ -60,17 +75,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray([alphacums[0]]
+ alphacums[ddim_timesteps[:-1]].tolist())

# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
# rewrite because of E125
tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
sigmas = (eta * np.sqrt(tmp))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
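A numeric sketch of the sigma formula above (the alpha values are invented for illustration): eta = 0 gives the deterministic DDIM schedule, while eta = 1 recovers DDPM-like noise levels.

import numpy as np

alphas = np.array([0.9, 0.5, 0.1])        # hypothetical alpha_cumprod at the chosen steps
alphas_prev = np.array([0.99, 0.9, 0.5])
for eta in (0.0, 1.0):
    tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(tmp)
    print(eta, sigmas)                    # eta=0 -> all zeros (deterministic sampling)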
|
||||
|
||||
@@ -96,7 +121,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
@@ -117,6 +142,7 @@ def checkpoint(func, inputs, params, flag):
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
@@ -129,7 +155,9 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
ctx.input_tensors = [
|
||||
x.detach().requires_grad_(True) for x in ctx.input_tensors
|
||||
]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
@@ -160,12 +188,14 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
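A compact standalone sketch of the sinusoidal branch of timestep_embedding above (toy sizes): frequencies decay geometrically from 1 down to 1/max_period, and the embedding concatenates cosines and sines of t times each frequency.

import math
import torch

timesteps, dim, max_period = torch.tensor([0, 10, 100]), 8, 10000
half = dim // 2
freqs = torch.exp(
    -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = timesteps[:, None].float() * freqs[None]          # (batch, half)
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(embedding.shape)                                    # torch.Size([3, 8])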
@@ -207,14 +237,17 @@ def normalization(channels):
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
@@ -225,7 +258,7 @@ def conv_nd(dims, *args, **kwargs):
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
@@ -245,7 +278,7 @@ def avg_pool_nd(dims, *args, **kwargs):
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
@@ -253,7 +286,8 @@ class HybridConditioner(nn.Module):
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(
|
||||
c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
@@ -262,6 +296,13 @@ class HybridConditioner(nn.Module):
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

def repeat_noise():
return torch.randn((1, *shape[1:]),
device=device).repeat(shape[0],
*((1, ) * (len(shape) - 1)))

def noise():
return torch.randn(shape, device=device)

return repeat_noise() if repeat else noise()

@@ -1,8 +1,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -11,6 +12,7 @@ class AbstractDistribution:
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
@@ -22,6 +24,7 @@ class DiracDistribution(AbstractDistribution):
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
@@ -30,10 +33,12 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
@@ -41,21 +46,22 @@ class DiagonalGaussianDistribution(object):
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
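For reference, the two branches of kl above are the standard closed forms for diagonal Gaussians (the sum runs over the non-batch dimensions, matching dim=[1, 2, 3]):

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,I)\big) = \tfrac{1}{2}\sum\big(\mu^{2}+\sigma^{2}-1-\log\sigma^{2}\big)

\mathrm{KL}\big(\mathcal{N}(\mu_{1},\sigma_{1}^{2})\,\|\,\mathcal{N}(\mu_{2},\sigma_{2}^{2})\big) = \tfrac{1}{2}\sum\Big(\frac{(\mu_{1}-\mu_{2})^{2}}{\sigma_{2}^{2}}+\frac{\sigma_{1}^{2}}{\sigma_{2}^{2}}-1-\log\sigma_{1}^{2}+\log\sigma_{2}^{2}\Big)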
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
@@ -64,7 +70,8 @@ class DiagonalGaussianDistribution(object):
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
source: https://github.com/openai/guided-diffusion/blob/
|
||||
27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
@@ -74,7 +81,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
assert tensor is not None, 'at least one argument must be a Tensor'
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
@@ -83,10 +90,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
# rewrite because of W504
|
||||
tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + tmp
|
||||
) # noqa
|
||||
|
||||
@@ -2,15 +2,17 @@ import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Any, Union, List
|
||||
from pkg_resources import packaging
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from pkg_resources import packaging
|
||||
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
|
||||
ToTensor)
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import \
|
||||
build_model
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
@@ -18,23 +20,40 @@ try:
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
if packaging.version.parse(
|
||||
torch.__version__) < packaging.version.parse('1.7.1'):
|
||||
warnings.warn('PyTorch version 1.7.1 or higher is recommended')
|
||||
|
||||
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
||||
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load"]
|
||||
__all__ = ['available_models', 'load']
|
||||
|
||||
_MODELS = {
|
||||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
||||
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
||||
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
||||
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
||||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
||||
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
||||
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
||||
'RN50':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
|
||||
'RN101':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
|
||||
'RN50x4':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
|
||||
'RN50x16':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
|
||||
'RN50x64':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
|
||||
'ViT-B/32':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
|
||||
'ViT-B/16':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
|
||||
'ViT-L/14':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
|
||||
'ViT-L/14@336px':
|
||||
'https://openaipublic.azureedge.net/clip/models/'
|
||||
'3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
|
||||
}
|
||||
|
||||
|
||||
@@ -42,20 +61,30 @@ def _download(url: str, root: str):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
expected_sha256 = url.split('/')[-2]
|
||||
download_target = os.path.join(root, filename)
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
raise RuntimeError(
|
||||
f'{download_target} exists and is not a regular file')
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
||||
if hashlib.sha256(open(download_target,
|
||||
'rb').read()).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
with urllib.request.urlopen(url) as source, open(download_target,
|
||||
'wb') as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get('Content-Length')),
|
||||
ncols=80,
|
||||
unit='iB',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
@@ -64,14 +93,17 @@ def _download(url: str, root: str):
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
||||
if hashlib.sha256(open(download_target,
|
||||
'rb').read()).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
'Model has been downloaded but the SHA256 checksum does not not match'
|
||||
)
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def _convert_image_to_rgb(image):
|
||||
return image.convert("RGB")
|
||||
return image.convert('RGB')
|
||||
|
||||
|
||||
def _transform(n_px):
|
||||
@@ -80,7 +112,8 @@ def _transform(n_px):
|
||||
CenterCrop(n_px),
|
||||
_convert_image_to_rgb,
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073),
|
||||
(0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
|
||||
@@ -89,7 +122,11 @@ def available_models() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
def load(name: str,
|
||||
device: Union[str, torch.device] = 'cuda'
|
||||
if torch.cuda.is_available() else 'cpu',
|
||||
jit: bool = False,
|
||||
download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
@@ -115,37 +152,47 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
||||
model_path = _download(
|
||||
_MODELS[name], download_root
|
||||
or os.path.expanduser('~/.cache/clip'))
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
raise RuntimeError(
|
||||
f'Model {name} not found; available models = {available_models()}')
|
||||
|
||||
with open(model_path, 'rb') as opened_file:
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||
model = torch.jit.load(
|
||||
opened_file, map_location=device if jit else 'cpu').eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
warnings.warn(
|
||||
f'File {model_path} is not a JIT archive. Loading as a state dict instead'
|
||||
)
|
||||
jit = False
|
||||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
state_dict = torch.load(opened_file, map_location='cpu')
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
if str(device) == 'cpu':
|
||||
model.float()
|
||||
return model, _transform(model.visual.input_resolution)
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
device_holder = torch.jit.trace(
|
||||
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [
|
||||
n for n in device_holder.graph.findAllNodes('prim::Constant')
|
||||
if 'Device' in repr(n)
|
||||
][-1]
|
||||
|
||||
def _node_get(node: torch._C.Node, key: str):
|
||||
"""Gets attributes of a node which is polymorphic over return type.
|
||||
|
||||
|
||||
From https://github.com/pytorch/pytorch/pull/82628
|
||||
"""
|
||||
sel = node.kindOf(key)
|
||||
@@ -153,16 +200,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||
|
||||
def patch_device(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
graphs = [module.graph] if hasattr(module, 'graph') else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
if hasattr(module, 'forward1'):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("prim::Constant"):
|
||||
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
|
||||
for node in graph.findAllNodes('prim::Constant'):
|
||||
if 'value' in node.attributeNames() and str(
|
||||
_node_get(node, 'value')).startswith('cuda'):
|
||||
node.copyAttributes(device_node)
|
||||
|
||||
model.apply(patch_device)
|
||||
@@ -170,25 +218,28 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
if str(device) == 'cpu':
|
||||
float_holder = torch.jit.trace(
|
||||
lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
graphs = [module.graph] if hasattr(module, 'graph') else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
if hasattr(module, 'forward1'):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("aten::to"):
|
||||
for node in graph.findAllNodes('aten::to'):
|
||||
inputs = list(node.inputs())
|
||||
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||
if _node_get(inputs[i].node(), "value") == 5:
|
||||
for i in [
|
||||
1, 2
|
||||
]: # dtype can be the second or third argument to aten::to()
|
||||
if _node_get(inputs[i].node(), 'value') == 5:
|
||||
inputs[i].node().copyAttributes(float_node)
|
||||
|
||||
model.apply(patch_float)
|
||||
|
||||
@@ -33,11 +33,16 @@ class Bottleneck(nn.Module):
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||
self.downsample = nn.Sequential(OrderedDict([
|
||||
("-1", nn.AvgPool2d(stride)),
|
||||
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||
]))
|
||||
self.downsample = nn.Sequential(
|
||||
OrderedDict([('-1', nn.AvgPool2d(stride)),
|
||||
('0',
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
stride=1,
|
||||
bias=False)),
|
||||
('1', nn.BatchNorm2d(planes * self.expansion))]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
@@ -56,9 +61,15 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||
|
||||
def __init__(self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
@@ -70,14 +81,17 @@ class AttentionPool2d(nn.Module):
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
query=x[:1],
|
||||
key=x,
|
||||
value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
in_proj_bias=torch.cat(
|
||||
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
@@ -86,8 +100,7 @@ class AttentionPool2d(nn.Module):
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
need_weights=False)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
@@ -99,19 +112,27 @@ class ModifiedResNet(nn.Module):
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||
def __init__(self,
|
||||
layers,
|
||||
output_dim,
|
||||
heads,
|
||||
input_resolution=224,
|
||||
width=64):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(
|
||||
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.conv3 = nn.Conv2d(
|
||||
width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
@@ -124,7 +145,8 @@ class ModifiedResNet(nn.Module):
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
|
||||
heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
@@ -136,6 +158,7 @@ class ModifiedResNet(nn.Module):
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def stem(x):
|
||||
x = self.relu1(self.bn1(self.conv1(x)))
|
||||
x = self.relu2(self.bn2(self.conv2(x)))
|
||||
@@ -164,27 +187,34 @@ class LayerNorm(nn.LayerNorm):
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
|
||||
('gelu', QuickGELU()),
|
||||
('c_proj', nn.Linear(d_model * 4, d_model))]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
self.attn_mask = self.attn_mask.to(
|
||||
dtype=x.dtype,
|
||||
device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(
|
||||
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
@@ -193,26 +223,42 @@ class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
||||
self.resblocks = nn.Sequential(*[
|
||||
ResidualAttentionBlock(width, heads, attn_mask)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
||||
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int,
|
||||
layers: int, heads: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False)
|
||||
|
||||
scale = width ** -0.5
|
||||
scale = width**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn(
|
||||
(input_resolution // patch_size)**2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
|
||||
self.transformer = Transformer(width, layers, heads)
|
||||
@@ -222,9 +268,15 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
|
||||
# rewrite because of E126
|
||||
tmp = self.class_embedding.to(x.dtype) + torch.zeros(
|
||||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) # noqs
|
||||
x = torch.cat([tmp, x], dim=1)
|
||||
# shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
@@ -241,20 +293,21 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int):
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
@@ -266,8 +319,7 @@ class CLIP(nn.Module):
|
||||
output_dim=embed_dim,
|
||||
heads=vision_heads,
|
||||
input_resolution=image_resolution,
|
||||
width=vision_width
|
||||
)
|
||||
width=vision_width)
|
||||
else:
|
||||
vision_heads = vision_width // 64
|
||||
self.visual = VisionTransformer(
|
||||
@@ -276,22 +328,22 @@ class CLIP(nn.Module):
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim
|
||||
)
|
||||
output_dim=embed_dim)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask()
|
||||
)
|
||||
attn_mask=self.build_attention_mask())
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
@@ -302,20 +354,24 @@ class CLIP(nn.Module):
|
||||
|
||||
if isinstance(self.visual, ModifiedResNet):
|
||||
if self.visual.attnpool is not None:
|
||||
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
||||
std = self.visual.attnpool.c_proj.in_features**-0.5
|
||||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
||||
|
||||
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
||||
for resnet_block in [
|
||||
self.visual.layer1, self.visual.layer2, self.visual.layer3,
|
||||
self.visual.layer4
|
||||
]:
|
||||
for name, param in resnet_block.named_parameters():
|
||||
if name.endswith("bn3.weight"):
|
||||
if name.endswith('bn3.weight'):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
attn_std = self.transformer.width ** -0.5
|
||||
fc_std = (2 * self.transformer.width) ** -0.5
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
@@ -323,13 +379,14 @@ class CLIP(nn.Module):
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
@@ -341,7 +398,8 @@ class CLIP(nn.Module):
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
||||
def encode_text(self, text):
|
||||
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
||||
x = self.token_embedding(text).type(
|
||||
self.dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.type(self.dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
@@ -351,7 +409,8 @@ class CLIP(nn.Module):
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||
x = x[torch.arange(x.shape[0]),
|
||||
text.argmax(dim=-1)] @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
@@ -360,7 +419,8 @@ class CLIP(nn.Module):
|
||||
text_features = self.encode_text(text)
|
||||
|
||||
# normalized features
|
||||
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
@@ -375,21 +435,24 @@ class CLIP(nn.Module):
|
||||
def convert_weights(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
def _convert_weights_to_fp16(_l):
|
||||
if isinstance(_l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
_l.weight.data = _l.weight.data.half()
|
||||
if _l.bias is not None:
|
||||
_l.bias.data = _l.bias.data.half()
|
||||
|
||||
if isinstance(l, nn.MultiheadAttention):
|
||||
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
tensor = getattr(l, attr)
|
||||
if isinstance(_l, nn.MultiheadAttention):
|
||||
for attr in [
|
||||
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
|
||||
'in_proj_bias', 'bias_k', 'bias_v'
|
||||
]:
|
||||
tensor = getattr(_l, attr)
|
||||
if tensor is not None:
|
||||
tensor.data = tensor.data.half()
|
||||
|
||||
for name in ["text_projection", "proj"]:
|
||||
if hasattr(l, name):
|
||||
attr = getattr(l, name)
|
||||
for name in ['text_projection', 'proj']:
|
||||
if hasattr(_l, name):
|
||||
attr = getattr(_l, name)
|
||||
if attr is not None:
|
||||
attr.data = attr.data.half()
|
||||
|
||||
@@ -397,37 +460,51 @@ def convert_weights(model: nn.Module):
|
||||
|
||||
|
||||
def build_model(state_dict: dict):
|
||||
vit = "visual.proj" in state_dict
|
||||
vit = 'visual.proj' in state_dict
|
||||
|
||||
if vit:
|
||||
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
vision_width = state_dict['visual.conv1.weight'].shape[0]
|
||||
vision_layers = len([
|
||||
k for k in state_dict.keys()
|
||||
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
|
||||
])
|
||||
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
|
||||
grid_size = round(
|
||||
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
else:
|
||||
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||
counts: list = [
|
||||
len(
|
||||
set(
|
||||
k.split('.')[2] for k in state_dict
|
||||
if k.startswith(f'visual.layer{b}')))
|
||||
for b in [1, 2, 3, 4]
|
||||
]
|
||||
vision_layers = tuple(counts)
|
||||
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
|
||||
output_width = round(
|
||||
(state_dict['visual.attnpool.positional_embedding'].shape[0]
|
||||
- 1)**0.5)
|
||||
vision_patch_size = None
|
||||
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||
assert output_width**2 + 1 == state_dict[
|
||||
'visual.attnpool.positional_embedding'].shape[0]
|
||||
image_resolution = output_width * 32
|
||||
|
||||
embed_dim = state_dict["text_projection"].shape[1]
|
||||
context_length = state_dict["positional_embedding"].shape[0]
|
||||
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
embed_dim = state_dict['text_projection'].shape[1]
|
||||
context_length = state_dict['positional_embedding'].shape[0]
|
||||
vocab_size = state_dict['token_embedding.weight'].shape[0]
|
||||
transformer_width = state_dict['ln_final.weight'].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split('.')[2] for k in state_dict
|
||||
if k.startswith('transformer.resblocks')))
|
||||
|
||||
model = CLIP(
|
||||
embed_dim,
|
||||
image_resolution, vision_layers, vision_width, vision_patch_size,
|
||||
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
||||
)
|
||||
model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
|
||||
vision_patch_size, context_length, vocab_size,
|
||||
transformer_width, transformer_heads, transformer_layers)
|
||||
|
||||
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||
for key in ['input_resolution', 'context_length', 'vocab_size']:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ import regex as re
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
||||
return os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'bpe_simple_vocab_16e6.txt.gz')
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -23,13 +25,17 @@ def bytes_to_unicode():
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
bs = list(range(ord('!'),
|
||||
ord('~') + 1)) + list(range(
|
||||
ord('¡'),
|
||||
ord('¬') + 1)) + list(range(ord('®'),
|
||||
ord('ÿ') + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
@@ -60,34 +66,41 @@ def whitespace_clean(text):
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
||||
merges = merges[1:49152 - 256 - 2 + 1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
vocab = vocab + [v + '</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
self.cache = {
|
||||
'<|startoftext|>': '<|startoftext|>',
|
||||
'<|endoftext|>': '<|endoftext|>'
|
||||
}
|
||||
self.pat = re.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
return token + '</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
@@ -98,12 +111,13 @@ class SimpleTokenizer(object):
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
if word[i] == first and i < len(word) - 1 and word[
|
||||
i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
@@ -122,11 +136,14 @@ class SimpleTokenizer(object):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
token = ''.join(self.byte_encoder[b]
|
||||
for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token]
|
||||
for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
'utf-8', errors='replace').replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
@@ -1,28 +1,45 @@
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import kornia
|
||||
import kornia.augmentation as K
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
import kornia
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel,
|
||||
T5EncoderModel, T5Tokenizer)
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import default
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
|
||||
extract_into_tensor, make_beta_schedule, noise_like)
|
||||
# import clip
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip
|
||||
# TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import (
|
||||
Encoder, TransformerWrapper)
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import (default,
|
||||
instantiate_from_config)
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class FaceClipEncoder(AbstractEncoder):
|
||||
|
||||
def __init__(self, augment=True, retreival_key=None):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
@@ -35,16 +52,16 @@ class FaceClipEncoder(AbstractEncoder):
|
||||
x_offset = 125
|
||||
if self.retreival_key:
|
||||
# Assumes retrieved image are packed into the second half of channels
|
||||
face = img[:,3:,190:440,x_offset:(512-x_offset)]
|
||||
other = img[:,:3,...].clone()
|
||||
face = img[:, 3:, 190:440, x_offset:(512 - x_offset)]
|
||||
other = img[:, :3, ...].clone()
|
||||
else:
|
||||
face = img[:,:,190:440,x_offset:(512-x_offset)]
|
||||
face = img[:, :, 190:440, x_offset:(512 - x_offset)]
|
||||
other = img.clone()
|
||||
|
||||
if self.augment:
|
||||
face = K.RandomHorizontalFlip()(face)
|
||||
|
||||
other[:,:,190:440,x_offset:(512-x_offset)] *= 0
|
||||
other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0
|
||||
encodings = [
|
||||
self.encoder.encode(face),
|
||||
self.encoder.encode(other),
|
||||
@@ -55,26 +72,32 @@ class FaceClipEncoder(AbstractEncoder):
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
|
||||
return torch.zeros(
|
||||
(1, 2, 768),
|
||||
device=self.encoder.model.visual.conv1.weight.device)
|
||||
|
||||
return self(img)
|
||||
|
||||
|
||||
class FaceIdClipEncoder(AbstractEncoder):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
for p in self.encoder.parameters():
|
||||
p.requires_grad = False
|
||||
self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)
|
||||
self.id = FrozenFaceEncoder(
|
||||
'/home/jpinkney/code/stable-diffusion/model_ir_se50.pth',
|
||||
augment=True)
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
face = kornia.geometry.resize(img, (256, 256),
|
||||
interpolation='bilinear', align_corners=True)
|
||||
face = kornia.geometry.resize(
|
||||
img, (256, 256), interpolation='bilinear', align_corners=True)
|
||||
|
||||
other = img.clone()
|
||||
other[:,:,184:452,122:396] *= 0
|
||||
other[:, :, 184:452, 122:396] *= 0
|
||||
encodings = [
|
||||
self.id.encode(face),
|
||||
self.encoder.encode(other),
|
||||
@@ -85,11 +108,15 @@ class FaceIdClipEncoder(AbstractEncoder):
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
|
||||
return torch.zeros(
|
||||
(1, 2, 768),
|
||||
device=self.encoder.model.visual.conv1.weight.device)
|
||||
|
||||
return self(img)
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
@@ -106,11 +133,19 @@ class ClassEmbedder(nn.Module):
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
|
||||
"""Some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
|
||||
|
||||
def __init__(self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size,
|
||||
max_seq_len=77,
|
||||
device='cuda'):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = tokens.to(self.device) # meh
|
||||
@@ -123,18 +158,25 @@ class TransformerEmbedder(AbstractEncoder):
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
|
||||
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
|
||||
def __init__(self, device='cuda', vq_interface=True, max_length=77):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
return tokens
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -150,20 +192,30 @@ class BERTTokenizer(AbstractEncoder):
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
||||
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
||||
|
||||
def __init__(self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size=30522,
|
||||
max_seq_len=77,
|
||||
device='cuda',
|
||||
use_tokenizer=True,
|
||||
embedding_dropout=0.0):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
if self.use_tknz_fn:
|
||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
||||
self.tknz_fn = BERTTokenizer(
|
||||
vq_interface=False, max_length=max_seq_len)
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout)
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout)
|
||||
|
||||
def forward(self, text):
|
||||
if self.use_tknz_fn:
|
||||
tokens = self.tknz_fn(text)#.to(self.device)
|
||||
tokens = self.tknz_fn(text) # .to(self.device)
|
||||
else:
|
||||
tokens = text
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
@@ -174,8 +226,6 @@ class BERTEmbedder(AbstractEncoder):
|
||||
return self(text)
|
||||
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
@@ -184,24 +234,41 @@ def disabled_train(self, mode=True):
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
|
||||
def __init__(self,
|
||||
version='google/t5-v1_1-large',
|
||||
device='cuda',
|
||||
max_length=77
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(
|
||||
version,
|
||||
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
|
||||
)
|
||||
self.transformer = T5EncoderModel.from_pretrained(
|
||||
version,
|
||||
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
@@ -210,10 +277,9 @@ class FrozenT5Embedder(AbstractEncoder):
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
|
||||
import kornia.augmentation as K
|
||||
|
||||
class FrozenFaceEncoder(AbstractEncoder):
|
||||
|
||||
def __init__(self, model_path, augment=False):
|
||||
super().__init__()
|
||||
self.loss_fn = IDFeatures(model_path)
|
||||
@@ -242,8 +308,8 @@ class FrozenFaceEncoder(AbstractEncoder):
|
||||
|
||||
if self.augment is not None:
|
||||
# Transforms require 0-1
|
||||
img = self.augment((img + 1)/2)
|
||||
img = 2*img - 1
|
||||
img = self.augment((img + 1) / 2)
|
||||
img = 2 * img - 1
|
||||
|
||||
feat = self.loss_fn(img, crop=True)
|
||||
feat = self.mapper(feat.unsqueeze(1))
|
||||
@@ -252,26 +318,43 @@ class FrozenFaceEncoder(AbstractEncoder):
|
||||
def encode(self, img):
|
||||
return self(img)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
|
||||
|
||||
def __init__(self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device='cuda',
|
||||
max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
version,
|
||||
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
|
||||
)
|
||||
self.transformer = CLIPTextModel.from_pretrained(
|
||||
version,
|
||||
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
@@ -280,36 +363,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
class ClipImageProjector(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
|
||||
|
||||
def __init__(self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.model = CLIPVisionModel.from_pretrained(version)
|
||||
self.model.train()
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.antialias = True
|
||||
self.mapper = torch.nn.Linear(1024, 768)
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False)
|
||||
self.register_buffer(
|
||||
'std',
|
||||
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
|
||||
persistent=False)
|
||||
null_cond = self.get_null_cond(version, max_length)
|
||||
self.register_buffer('null_cond', null_cond)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_null_cond(self, version, max_length):
|
||||
device = self.mean.device
|
||||
embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
|
||||
null_cond = embedder([""])
|
||||
embedder = FrozenCLIPEmbedder(
|
||||
version=version, device=device, max_length=max_length)
|
||||
null_cond = embedder([''])
|
||||
return null_cond
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = kornia.geometry.resize(
|
||||
x, (224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
@@ -323,15 +417,23 @@ class ClipImageProjector(AbstractEncoder):
|
||||
outputs = self.model(pixel_values=x)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
last_hidden_state = self.mapper(last_hidden_state)
|
||||
return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])
|
||||
return F.pad(
|
||||
last_hidden_state,
|
||||
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0])
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
|
||||
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
|
||||
|
||||
def __init__(self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device='cuda',
|
||||
max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
|
||||
self.embedder = FrozenCLIPEmbedder(
|
||||
version=version, device=device, max_length=max_length)
|
||||
self.projection = torch.nn.Linear(768, 768)
|
||||
|
||||
def forward(self, text):
|
||||
@@ -341,31 +443,41 @@ class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=False,
|
||||
):
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False)
|
||||
self.register_buffer(
|
||||
'std',
|
||||
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
|
||||
persistent=False)
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = kornia.geometry.resize(
|
||||
x, (224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
@@ -382,35 +494,41 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
|
||||
def encode(self, im):
|
||||
return self(im).unsqueeze(1)
|
||||
|
||||
from torchvision import transforms
|
||||
import random
|
||||
|
||||
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=True,
|
||||
max_crops=5,
|
||||
):
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=True,
|
||||
max_crops=5,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
self.max_crops = max_crops

def preprocess(self, x):

# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
randcrop = transforms.RandomResizedCrop(
224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
@@ -441,7 +559,9 @@ class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
def encode(self, im):
return self(im)


class SpatialRescaler(nn.Module):

def __init__(self,
n_stages=1,
method='bilinear',
@@ -452,19 +572,24 @@ class SpatialRescaler(nn.Module):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias)

def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)


if self.remap_output:
x = self.channel_mapper(x)
return x
@@ -473,25 +598,38 @@ class SpatialRescaler(nn.Module):
return self(x)


from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like


class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,

def __init__(self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
linear_end=linear_end)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size
self.scale_factor = scale_factor

def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
def register_schedule(self,
beta_schedule='linear',
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
@@ -500,33 +638,45 @@ class LowScaleEncoder(nn.Module):
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
assert alphas_cumprod.shape[
0] == self.num_timesteps, 'alphas have to be defined for each timestep'

to_torch = partial(torch.tensor, dtype=torch.float32)

self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
self.register_buffer('alphas_cumprod_prev',
to_torch(alphas_cumprod_prev))

# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod - 1)))

def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
* x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
x_start.shape) * noise)

def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0], ), device=x.device).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
z = torch.nn.functional.interpolate(
z, size=self.out_size,
mode='nearest')  # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level

@@ -535,10 +685,13 @@ class LowScaleEncoder(nn.Module):
return self.model.decode(z)


if __name__ == "__main__":
if __name__ == '__main__':
from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
sentences = [
'a hedgehog drinking a whiskey', 'der mond ist aufgegangen',
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"
]
model = FrozenT5Embedder(version='google/t5-v1_1-xl').cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
@@ -548,4 +701,4 @@ if __name__ == "__main__":
z = model(sentences)
print(z.shape)

print("done.")
print('done.')

@@ -1,28 +1,26 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch
from torch import nn, einsum
import torch.nn.functional as F
from collections import namedtuple
from functools import partial
from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce

import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn

# constants

DEFAULT_DIM_HEAD = 64

Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple('Intermediates',
['pre_softmax_attn', 'post_softmax_attn'])

LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates'
])
LayerIntermediates = namedtuple('Intermediates',
['hiddens', 'attn_intermediates'])


class AbsolutePositionalEmbedding(nn.Module):

def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
@@ -37,13 +35,15 @@ class AbsolutePositionalEmbedding(nn.Module):


class FixedPositionalEmbedding(nn.Module):

def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
t = torch.arange(
x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
@@ -51,6 +51,7 @@ class FixedPositionalEmbedding(nn.Module):

# helpers


def exists(val):
return val is not None

@@ -62,20 +63,26 @@ def default(val, d):


def always(val):

def inner(*args, **kwargs):
return val

return inner


def not_equals(val):

def inner(x):
return x != val

return inner


def equals(val):

def inner(x):
return x == val

return inner


@@ -85,6 +92,7 @@ def max_neg_value(tensor):

# keyword argument helpers


def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@@ -96,7 +104,7 @@ def group_dict_by_key(cond, d):
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
return (*return_val, )


def string_begins_with(prefix, str):
@@ -108,13 +116,17 @@ def group_by_key_prefix(prefix, d):


def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix):], x[1]),
tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs


# classes
class Scale(nn.Module):

def __init__(self, value, fn):
super().__init__()
self.value = value
@@ -126,6 +138,7 @@ class Scale(nn.Module):


class Rezero(nn.Module):

def __init__(self, fn):
super().__init__()
self.fn = fn
@@ -137,9 +150,10 @@ class Rezero(nn.Module):


class ScaleNorm(nn.Module):

def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))

@@ -149,9 +163,10 @@ class ScaleNorm(nn.Module):


class RMSNorm(nn.Module):

def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))

@@ -161,11 +176,13 @@ class RMSNorm(nn.Module):


class Residual(nn.Module):

def forward(self, x, residual):
return x + residual


class GRUGating(nn.Module):

def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
@@ -173,15 +190,16 @@ class GRUGating(nn.Module):
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
rearrange(residual, 'b n d -> (b n) d'))

return gated_output.reshape_as(x)


# feedforward


class GEGLU(nn.Module):

def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -192,20 +210,16 @@ class GEGLU(nn.Module):


class FeedForward(nn.Module):

def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
@@ -213,24 +227,24 @@ class FeedForward(nn.Module):

# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
):

def __init__(self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!')
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
@@ -252,7 +266,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk

# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax

# add memory key / values
@@ -263,19 +277,19 @@ class Attention(nn.Module):

# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
self.to_out = nn.Sequential(nn.Linear(
inner_dim, dim
* 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
):
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)

@@ -297,23 +311,29 @@ class Attention(nn.Module):
k = self.to_k(k_input)
v = self.to_v(v_input)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))

input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
q_mask = default(mask, lambda: torch.ones(
(b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
k_mask = default(
k_mask, lambda: torch.ones(
(b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask

if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True)

dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
@@ -324,7 +344,8 @@ class Attention(nn.Module):
pre_softmax_attn = dots

if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum('b h i j, h k -> b k i j', dots,
self.pre_softmax_proj).contiguous()

if exists(rel_pos):
dots = rel_pos(dots)
@@ -336,7 +357,8 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@@ -354,59 +376,60 @@ class Attention(nn.Module):
attn = self.dropout(attn)

if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum('b h i j, h k -> b k i j', attn,
self.post_softmax_proj).contiguous()

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
post_softmax_attn=post_softmax_attn)

return self.to_out(out), intermediates


class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
):

def __init__(self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
# dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])

self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = FixedPositionalEmbedding(
dim) if position_infused_attn else None
self.rotary_pos_emb = always(None)

assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
assert rel_pos_num_buckets <= rel_pos_max_distance, \
'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None

self.pre_norm = pre_norm
@@ -429,7 +452,7 @@ class AttentionLayers(nn.Module):
default_block = ('a', 'f')

if macaron:
default_block = ('f',) + default_block
default_block = ('f', ) + default_block

if exists(custom_layers):
layer_types = custom_layers
@@ -440,13 +463,17 @@ class AttentionLayers(nn.Module):
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert len(
default_block
) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f', ) * (
par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
layer_types = par_head + ('f', ) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
layer_types = ('a', ) * sandwich_coef + default_block * (
depth - sandwich_coef) + ('f', ) * sandwich_coef
else:
layer_types = default_block * depth

@@ -455,7 +482,8 @@ class AttentionLayers(nn.Module):

for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
@@ -472,21 +500,15 @@ class AttentionLayers(nn.Module):
else:
residual_fn = Residual()

self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))

def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False
):
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False):
hiddens = []
intermediates = []
prev_attn = None
@@ -494,7 +516,8 @@ class AttentionLayers(nn.Module):

mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)

if layer_type == 'a':
@@ -507,10 +530,20 @@ class AttentionLayers(nn.Module):
x = norm(x)

if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn)
elif layer_type == 'f':
out = block(x)

@@ -529,9 +562,7 @@ class AttentionLayers(nn.Module):

if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
)
hiddens=hiddens, attn_intermediates=intermediates)

return x, intermediates

@@ -539,28 +570,29 @@ class AttentionLayers(nn.Module):


class Encoder(AttentionLayers):

def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)



class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
):

def __init__(self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'

dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@@ -571,22 +603,26 @@ class TransformerWrapper(nn.Module):

self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)

self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = nn.Linear(emb_dim,
dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)

self.init_()

self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = nn.Linear(
dim, num_tokens
) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim))

# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
@@ -595,17 +631,16 @@ class TransformerWrapper(nn.Module):
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)

def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
def forward(self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs):
# b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
b, _, num_mem = *x.shape, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
@@ -620,7 +655,8 @@ class TransformerWrapper(nn.Module):
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)

x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)

mem, x = x[:, :num_mem], x[:, num_mem:]
@@ -629,13 +665,18 @@ class TransformerWrapper(nn.Module):

if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = list(
map(lambda pair: torch.cat(pair, dim=-2), zip(
mems, hiddens))) if exists(mems) else hiddens
new_mems = list(
map(lambda t: t[..., -self.max_mem_len:, :].detach(),
new_mems))
return out, new_mems

if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = list(
map(lambda t: t.post_softmax_attn,
intermediates.attn_intermediates))
return out, attn_maps

return out

@@ -1,121 +1,134 @@
# https://github.com/eladrich/pixel2style2pixel

from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""

from collections import namedtuple

import torch
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d,
Module, PReLU, ReLU, Sequential, Sigmoid)


class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)

def forward(self, input):
return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
""" A named tuple describing a ResNet block. """
""" A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
return [Bottleneck(in_channel, depth, stride)
] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
return blocks
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError(
'Invalid number of layers: {}. Must be one of [50, 100, 152]'.
format(num_layers))
return blocks


class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = Sigmoid()

def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()

def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x


class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
)

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut


class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16)
)

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth), SEModule(depth, 16))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut

@@ -1,22 +1,26 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn

from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone


class IDFeatures(nn.Module):

def __init__(self, model_path):
super(IDFeatures, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
self.facenet = Backbone(
input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(
torch.load(model_path, map_location='cpu'))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()

def forward(self, x, crop=False):
# Not sure of the image range here
if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
x = torch.nn.functional.interpolate(x, (256, 256), mode='area')
x = x[:, :, 35:223, 32:220]
x = self.face_pool(x)
x_feats = self.facenet(x)

|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
|
||||
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
||||
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
|
||||
Module, PReLU, Sequential)
|
||||
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import (
|
||||
Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, l2_norm)
|
||||
|
||||
|
||||
class Backbone(Module):
|
||||
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
||||
super(Backbone, self).__init__()
|
||||
assert input_size in [112, 224], "input_size should be 112 or 224"
|
||||
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
||||
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
||||
blocks = get_blocks(num_layers)
|
||||
if mode == 'ir':
|
||||
unit_module = bottleneck_IR
|
||||
elif mode == 'ir_se':
|
||||
unit_module = bottleneck_IR_SE
|
||||
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
BatchNorm2d(64),
|
||||
PReLU(64))
|
||||
if input_size == 112:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 7 * 7, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
else:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 14 * 14, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
modules = []
|
||||
for block in blocks:
|
||||
for bottleneck in block:
|
||||
modules.append(unit_module(bottleneck.in_channel,
|
||||
bottleneck.depth,
|
||||
bottleneck.stride))
|
||||
self.body = Sequential(*modules)
|
||||
def __init__(self,
|
||||
input_size,
|
||||
num_layers,
|
||||
mode='ir',
|
||||
drop_ratio=0.4,
|
||||
affine=True):
|
||||
super(Backbone, self).__init__()
|
||||
assert input_size in [112, 224], 'input_size should be 112 or 224'
|
||||
assert num_layers in [50, 100,
|
||||
152], 'num_layers should be 50, 100 or 152'
|
||||
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
||||
blocks = get_blocks(num_layers)
|
||||
if mode == 'ir':
|
||||
unit_module = bottleneck_IR
|
||||
elif mode == 'ir_se':
|
||||
unit_module = bottleneck_IR_SE
|
||||
self.input_layer = Sequential(
|
||||
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
|
||||
PReLU(64))
|
||||
if input_size == 112:
|
||||
self.output_layer = Sequential(
|
||||
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
|
||||
Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine))
|
||||
else:
|
||||
self.output_layer = Sequential(
|
||||
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
|
||||
Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return l2_norm(x)
|
||||
modules = []
|
||||
for block in blocks:
|
||||
for bottleneck in block:
|
||||
modules.append(
|
||||
unit_module(bottleneck.in_channel, bottleneck.depth,
|
||||
bottleneck.stride))
|
||||
self.body = Sequential(*modules)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return l2_norm(x)
|
||||
|
||||
|
||||
def IR_50(input_size):
|
||||
"""Constructs a ir-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir-50 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_101(input_size):
|
||||
"""Constructs a ir-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir-101 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_152(input_size):
|
||||
"""Constructs a ir-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir-152 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_50(input_size):
|
||||
"""Constructs a ir_se-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir_se-50 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_101(input_size):
|
||||
"""Constructs a ir_se-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir_se-101 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_152(input_size):
|
||||
"""Constructs a ir_se-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
"""Constructs a ir_se-152 model."""
|
||||
model = Backbone(
|
||||
input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
@@ -1,32 +1,24 @@
import importlib

import torchvision
import torch
from torch import optim
import numpy as np

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import time
import cv2
import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim


def pil_rectangle_crop(im):
width, height = im.size  # Get dimensions

width, height = im.size  # Get dimensions

if width <= height:
left = 0
right = width
top = (height - width)/2
bottom = (height + width)/2
top = (height - width) / 2
bottom = (height + width) / 2
else:


top = 0
bottom = height
left = (width - height) / 2
@@ -36,6 +28,7 @@ def pil_rectangle_crop(im):
im = im.crop((left, top, right, bottom))
return im


def add_margin(pil_img, color=0, size=256):
width, height = pil_img.size
result = Image.new(pil_img.mode, (size, size), color)
@@ -46,16 +39,17 @@ def add_margin(pil_img, color=0, size=256):
def create_carvekit_interface():
from carvekit.api.high import HiInterface
# Check doc strings for more information
interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=False)
interface = HiInterface(
object_type='object',  # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=False)

return interface

@@ -72,17 +66,17 @@ def load_and_preprocess(interface, input_im):
image_without_background = np.array(image_without_background)
est_seg = image_without_background > 127
image = np.array(image)
foreground = est_seg[:, : , -1].astype(np.bool_)
foreground = est_seg[:, :, -1].astype(np.bool_)
image[~foreground] = [255., 255., 255.]
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
image = image[y:y+h, x:x+w, :]
image = image[y:y + h, x:x + w, :]
image = PIL.Image.fromarray(np.array(image))


# resize image such that long edge is 512
image.thumbnail([200, 200], Image.LANCZOS)
image = add_margin(image, (255, 255, 255), size=256)
image = np.array(image)


return image


@@ -92,16 +86,17 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
lines = '\n'.join(xc[bi][start:start + nc]
for start in range(0, len(xc[bi]), nc))

try:
draw.text((0, 0), lines, fill="black", font=font)
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
print('Cant encode string for logging. Skipping.')

txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -117,7 +112,7 @@ def ismap(x):


def isimage(x):
if not isinstance(x,torch.Tensor):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)

@@ -143,22 +138,24 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
print(
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
)
return total_params


def instantiate_from_config(config):
if not "target" in config:
if 'target' not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
elif config == '__is_unconditional__':
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(**config.get('params', dict()))


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
module, cls = string.rsplit('.', 1)
print(module)
if reload:
module_imp = importlib.import_module(module)
@@ -168,25 +165,42 @@ def get_obj_from_str(string, reload=False):

class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
ema_power=1., param_names=()):
def __init__(
self,
params,
lr=1.e-3,
betas=(0.9, 0.999),
eps=1.e-8,  # TODO: check hyperparameters before using
weight_decay=1.e-2,
amsgrad=False,
ema_decay=0.9999,  # ema decay to match previous code
ema_power=1.,
param_names=()):  # noqa
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError('Invalid epsilon value: {}'.format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names)
super().__init__(params, defaults)

def __setstate__(self, state):
@@ -212,7 +226,7 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
# state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
@@ -225,7 +239,8 @@ class AdamWwithEMAandWings(optim.Optimizer):
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
raise RuntimeError(
'AdamW does not support sparse gradients')
grads.append(p.grad)

state = self.state[p]
@@ -234,12 +249,15 @@ class AdamWwithEMAandWings(optim.Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['max_exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()

@@ -255,22 +273,25 @@ class AdamWwithEMAandWings(optim.Optimizer):
# record the step after step update
state_steps.append(state['step'])

optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
optim._functional.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)

cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
for param, ema_param in zip(params_with_grad,
ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay)

return loss
return loss

@@ -1,28 +1,27 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
import rembg
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import rembg
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from torchvision.utils import save_image
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# import modelscope.models.cv.image_to_image_generation.data as data
|
||||
# import modelscope.models.cv.image_to_image_generation.models as models
|
||||
# import modelscope.models.cv.image_to_image_generation.ops as ops
|
||||
from modelscope.metainfo import Pipelines
|
||||
# from modelscope.models.cv.image_to_3d.model import UNet
|
||||
# from modelscope.models.cv.image_to_image_generation.models.clip import \
|
||||
# VisionTransformer
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config, add_margin
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import \
|
||||
SyncMultiviewDiffusion
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import (add_margin,
|
||||
instantiate_from_config)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
@@ -31,23 +30,29 @@ from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
# from modelscope.models.cv.image_to_3d.model import UNet
|
||||
# from modelscope.models.cv.image_to_image_generation.models.clip import \
|
||||
# VisionTransformer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# Load Syncdreamer Model
|
||||
def load_model(cfg, ckpt, strict=True):
|
||||
config = OmegaConf.load(cfg)
|
||||
model = instantiate_from_config(config.model)
|
||||
print(f'loading model from {ckpt} ...')
|
||||
ckpt = torch.load(ckpt,map_location='cpu')
|
||||
model.load_state_dict(ckpt['state_dict'],strict=strict)
|
||||
ckpt = torch.load(ckpt, map_location='cpu')
|
||||
model.load_state_dict(ckpt['state_dict'], strict=strict)
|
||||
model = model.cuda().eval()
|
||||
return model
|
||||
|
||||
|
||||
# Prepare Syncdreamer Input
|
||||
def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
|
||||
image_input[:,:,:3] = image_input[:,:,:3][:,:,::-1]
|
||||
image_input[:, :, :3] = image_input[:, :, :3][:, :, ::-1]
|
||||
image_input = Image.fromarray(image_input)
|
||||
if crop_size!=-1:
|
||||
if crop_size != -1:
|
||||
alpha_np = np.asarray(image_input)[:, :, 3]
|
||||
coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
|
||||
min_x, min_y = np.min(coords, 0)
|
||||
@@ -59,21 +64,26 @@ def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
        ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
        image_input = add_margin(ref_img_, size=image_size)
    else:
        image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
        image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
        image_input = add_margin(
            image_input, size=max(image_input.height, image_input.width))
        image_input = image_input.resize((image_size, image_size),
                                         resample=Image.BICUBIC)

    image_input = np.asarray(image_input)
    image_input = image_input.astype(np.float32) / 255.0
    ref_mask = image_input[:, :, 3:]
    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
    image_input[:, :, :
                3] = image_input[:, :, :
                                 3] * ref_mask + 1 - ref_mask # white background
    image_input = image_input[:, :, :3] * 2.0 - 1.0
    image_input = torch.from_numpy(image_input.astype(np.float32))
    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
    return {"input_image": image_input, "input_elevation": elevation_input}
    elevation_input = torch.from_numpy(
        np.asarray([np.deg2rad(elevation_input)], np.float32))
    return {'input_image': image_input, 'input_elevation': elevation_input}


@PIPELINES.register_module(
    Tasks.image_to_3d,
    module_name=Pipelines.image_to_3d)
    Tasks.image_to_3d, module_name=Pipelines.image_to_3d)
class Image23DPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
@@ -91,23 +101,28 @@ class Image23DPipeline(Pipeline):
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        ckpt = config_path.replace("configuration.json", "syncdreamer-pretrain.ckpt")
        self.model = load_model(config_path.replace("configuration.json", "syncdreamer.yaml"), ckpt).to(self._device)
        ckpt = config_path.replace('configuration.json',
                                   'syncdreamer-pretrain.ckpt')
        self.model = load_model(
            config_path.replace('configuration.json', 'syncdreamer.yaml'),
            ckpt).to(self._device)
        # os.system("pip install -r {}".format(config_path.replace("configuration.json", "requirements.txt")))
        # assert isinstance(self.model, SyncMultiviewDiffusion)

    def preprocess(self, input: Input) -> Dict[str, Any]:


        result = rembg.remove(Image.open(input))
        print(type(result))
        img = np.array(result)
        img[:,:,:3] = img[:,:,:3][:,:,::-1]
        img[:, :, :3] = img[:, :, :3][:, :, ::-1]
        # img = cv2.imread(input)
        data = prepare_inputs(img, elevation_input=10, crop_size=200, image_size=256)

        for k,v in data.items():
        data = prepare_inputs(
            img, elevation_input=10, crop_size=200, image_size=256)

        for k, v in data.items():
            data[k] = v.unsqueeze(0).cuda()
            data[k] = torch.repeat_interleave(data[k], 1, dim=0) # only one sample
            data[k] = torch.repeat_interleave(
                data[k], 1, dim=0) # only one sample
        return data

    @torch.no_grad()
@@ -115,11 +130,11 @@ class Image23DPipeline(Pipeline):
        x_sample = self.model.sample(input, 2.0, 8)

        B, N, _, H, W = x_sample.shape
        x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
        x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
        x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
        x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
        x_sample = x_sample.astype(np.uint8)
        show_in_im2 = [Image.fromarray(x_sample[0,ni]) for ni in range(N)]
        return {'MViews':show_in_im2}
        show_in_im2 = [Image.fromarray(x_sample[0, ni]) for ni in range(N)]
        return {'MViews': show_in_im2}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

@@ -3,6 +3,7 @@ import unittest

import numpy as np
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
@@ -24,11 +25,11 @@ class ImageTo3DTest(unittest.TestCase):
    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input['input_path'])
        np_content = []
        for idx,img in enumerate(result['MViews']):
        for idx, img in enumerate(result['MViews']):
            np_content.append(np.array(result['MViews'][idx]))

        np_content = np.concatenate(np_content, axis=1)
        Image.fromarray(np_content).save("./concat.png")
        Image.fromarray(np_content).save('./concat.png')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
@@ -38,4 +39,4 @@ class ImageTo3DTest(unittest.TestCase):


if __name__ == '__main__':
    unittest.main()
    unittest.main()
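For context, a minimal usage sketch of the pipeline touched by this diff, mirroring the test's use of pipeline(...) and the 'MViews' output; the model id below is a placeholder (not taken from this commit) and a CUDA device is assumed, since the pipeline moves the model and inputs to GPU:

import numpy as np
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Hypothetical model id -- substitute the actual image-to-3d model on ModelScope.
model_id = '<image-to-3d-model-id>'

# Build the registered image-to-3d pipeline and run it on a single image path.
img23d = pipeline(Tasks.image_to_3d, model=model_id)
result = img23d('input.png')

# 'MViews' holds the generated multi-view images as PIL.Image objects,
# which the test above concatenates horizontally and saves to disk.
views = [np.array(v) for v in result['MViews']]
Image.fromarray(np.concatenate(views, axis=1)).save('concat.png')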