fix flake8

ly119399
2024-01-03 00:05:04 +08:00
parent 19cb79ae6c
commit ec07a9919a
25 changed files with 3082 additions and 1857 deletions

View File

@@ -1,2 +1,2 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from . import ldm

View File

@@ -1,6 +1,7 @@
import pickle

import cv2
import numpy as np
from skimage.io import imread
@@ -9,16 +10,18 @@ def save_pickle(data, pkl_path):
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)


def read_pickle(pkl_path):
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


def draw_epipolar_line(F, img0, img1, pt0, color):
    h1, w1 = img1.shape[:2]
    hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None]
    _l = F @ hpt
    _l = _l[:, 0]
    a, b, c = _l[0], _l[1], _l[2]
    pt1 = np.asarray([0, -c / b]).astype(np.int32)
    pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32)
@@ -26,8 +29,9 @@ def draw_epipolar_line(F, img0, img1, pt0, color):
    img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2)
    return img0, img1


def draw_epipolar_lines(F, img0, img1, num=20):
    img0, img1 = img0.copy(), img1.copy()
    h0, w0, _ = img0.shape
    h1, w1, _ = img1.shape
@@ -42,117 +46,166 @@ def draw_epipolar_lines(F, img0, img1,num=20):
    return img0, img1


def compute_F(K1, K2, Rt0, Rt1=None):
    if Rt1 is None:
        R, t = Rt0[:, :3], Rt0[:, 3:]
    else:
        Rt = compute_dR_dt(Rt0, Rt1)
        R, t = Rt[:, :3], Rt[:, 3:]
    A = K1 @ R.T @ t  # [3,1]
    C = np.asarray([[0, -A[2, 0], A[1, 0]], [A[2, 0], 0, -A[0, 0]],
                    [-A[1, 0], A[0, 0], 0]])
    F = (np.linalg.inv(K2)).T @ R @ K1.T @ C
    return F


def compute_dR_dt(Rt0, Rt1):
    R0, t0 = Rt0[:, :3], Rt0[:, 3:]
    R1, t1 = Rt1[:, :3], Rt1[:, 3:]
    dR = np.dot(R1, R0.T)
    dt = t1 - np.dot(dR, t0)
    return np.concatenate([dR, dt], -1)


def concat_images(img0, img1, vert=False):
    if not vert:
        h0, h1 = img0.shape[0], img1.shape[0],
        if h0 < h1:
            img0 = cv2.copyMakeBorder(
                img0,
                0,
                h1 - h0,
                0,
                0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        if h1 < h0:
            img1 = cv2.copyMakeBorder(
                img1,
                0,
                h0 - h1,
                0,
                0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        img = np.concatenate([img0, img1], axis=1)
    else:
        w0, w1 = img0.shape[1], img1.shape[1]
        if w0 < w1:
            img0 = cv2.copyMakeBorder(
                img0,
                0,
                0,
                0,
                w1 - w0,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        if w1 < w0:
            img1 = cv2.copyMakeBorder(
                img1,
                0,
                0,
                0,
                w0 - w1,
                borderType=cv2.BORDER_CONSTANT,
                value=0)
        img = np.concatenate([img0, img1], axis=0)
    return img


def concat_images_list(*args, vert=False):
    if len(args) == 1:
        return args[0]
    img_out = args[0]
    for img in args[1:]:
        img_out = concat_images(img_out, img, vert)
    return img_out


def pose_inverse(pose):
    R = pose[:, :3].T
    t = -R @ pose[:, 3:]
    return np.concatenate([R, t], -1)


def project_points(pts, RT, K):
    pts = np.matmul(pts, RT[:, :3].transpose()) + RT[:, 3:].transpose()
    pts = np.matmul(pts, K.transpose())
    dpt = pts[:, 2]
    mask0 = (np.abs(dpt) < 1e-4) & (np.abs(dpt) > 0)
    if np.sum(mask0) > 0:
        dpt[mask0] = 1e-4
    mask1 = (np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
    if np.sum(mask1) > 0:
        dpt[mask1] = -1e-4
    pts2d = pts[:, :2] / dpt[:, None]
    return pts2d, dpt


def draw_keypoints(img, kps, colors=None, radius=2):
    out_img = img.copy()
    for pi, pt in enumerate(kps):
        pt = np.round(pt).astype(np.int32)
        if colors is not None:
            color = [int(c) for c in colors[pi]]
            cv2.circle(out_img, tuple(pt), radius, color, -1)
        else:
            cv2.circle(out_img, tuple(pt), radius, (0, 255, 0), -1)
    return out_img


def output_points(fn, pts, colors=None):
    with open(fn, 'w') as f:
        for pi, pt in enumerate(pts):
            f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
            if colors is not None:
                f.write(
                    f'{int(colors[pi, 0])} {int(colors[pi, 1])} {int(colors[pi, 2])}'
                )
            f.write('\n')


DEPTH_MAX, DEPTH_MIN = 2.4, 0.6
DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63


def read_depth_objaverse(depth_fn):
    depth = imread(depth_fn)
    depth = depth.astype(
        np.float32) / 65535 * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN
    mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX)
    return depth, mask


def mask_depth_to_pts(mask, depth, K, rgb=None):
    hs, ws = np.nonzero(mask)
    depth = depth[hs, ws]
    pts = np.asarray([ws, hs, depth], np.float32).transpose()
    pts[:, :2] *= pts[:, 2:]
    if rgb is not None:
        return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs, ws]
    else:
        return np.dot(pts, np.linalg.inv(K).transpose())


def transform_points_pose(pts, pose):
    R, t = pose[:, :3], pose[:, 3]
    if len(pts.shape) == 1:
        return (R @ pts[:, None] + t[:, None])[:, 0]
    return pts @ R.T + t[None, :]


def pose_apply(pose, pts):
    return transform_points_pose(pts, pose)


def downsample_gaussian_blur(img, ratio):
    sigma = (1 / ratio) / 3
    # ksize=np.ceil(2*sigma)
    ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1))
    ksize = ksize + 1 if ksize % 2 == 0 else ksize
    img = cv2.GaussianBlur(
        img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
    return img
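A minimal usage sketch for the utilities above (illustrative only, not code from the commit): it assumes compute_F, draw_epipolar_lines and concat_images_list are in scope from this module, and the intrinsics and extrinsics below are made-up values.

import numpy as np

img0 = np.zeros((480, 640, 3), np.uint8)  # stand-ins for two real views
img1 = np.zeros((480, 640, 3), np.uint8)
K1 = K2 = np.asarray([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
Rt0 = np.concatenate([np.eye(3), np.zeros((3, 1))], -1)                 # reference pose
Rt1 = np.concatenate([np.eye(3), np.asarray([[0.1], [0.], [0.]])], -1)  # shifted view

F = compute_F(K1, K2, Rt0, Rt1)                  # fundamental matrix between the views
vis0, vis1 = draw_epipolar_lines(F, img0, img1)  # overlay sampled epipolar lines
both = concat_images_list(vis0, vis1)            # side-by-side visualization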

View File

@@ -1,34 +1,36 @@
from contextlib import contextmanager

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import (
    Decoder, Encoder)
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import \
    DiagonalGaussianDistribution
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config


class VQModel(pl.LightningModule):

    def __init__(
            self,
            ddconfig,
            lossconfig,
            n_embed,
            embed_dim,
            ckpt_path=None,
            ignore_keys=[],
            image_key='image',
            colorize_nlabels=None,
            monitor=None,
            batch_resize_range=None,
            scheduler_config=None,
            lr_g_factor=1.0,
            remap=None,
            sane_index_shape=False,  # tell vector quantizer to return indices as bhw
            use_ema=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
@@ -36,24 +38,31 @@ class VQModel(pl.LightningModule):
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
            )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -66,28 +75,30 @@ class VQModel(pl.LightningModule):
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f'{context}: Switched to EMA weights')
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f'{context}: Restored training weights')

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
            print(f'Unexpected Keys: {unexpected}')

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
@@ -115,7 +126,7 @@ class VQModel(pl.LightningModule):
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
@@ -125,7 +136,8 @@ class VQModel(pl.LightningModule):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
@@ -133,9 +145,10 @@ class VQModel(pl.LightningModule):
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode='bicubic')
            x = x.detach()
        return x
@@ -147,79 +160,122 @@ class VQModel(pl.LightningModule):
        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train',
                predicted_indices=ind)

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')
            self.log_dict(
                log_dict_disc,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix='_ema')
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=''):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(
            qloss,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val' + suffix,
            predicted_indices=ind)

        discloss, log_dict_disc = self.loss(
            qloss,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val' + suffix,
            predicted_indices=ind)
        rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
        self.log(
            f'val{suffix}/rec_loss',
            rec_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True)
        self.log(
            f'val{suffix}/aeloss',
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f'val{suffix}/rec_loss']
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print('lr_d', lr_d)
        print('lr_g', lr_g)
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters())
            + list(self.quantize.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr_g,
            betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print('Setting up LambdaLR scheduler...')
            scheduler = [
                {
                    'scheduler':
                    LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler':
                    LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
@@ -235,7 +291,7 @@ class VQModel(pl.LightningModule):
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log['inputs'] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
@@ -243,25 +299,28 @@ class VQModel(pl.LightningModule):
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log['inputs'] = x
        log['reconstructions'] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log['reconstructions_ema'] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class VQModelInterface(VQModel):

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim
@@ -283,43 +342,48 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):

    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key='image',
        colorize_nlabels=None,
        monitor=None,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f'Restored from {path}')

    def encode(self, x):
        h = self.encoder(x)
@@ -345,7 +409,8 @@ class AutoencoderKL(pl.LightningModule):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
@@ -354,44 +419,91 @@ class AutoencoderKL(pl.LightningModule):
        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')
            self.log(
                'aeloss',
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True)
            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')

            self.log(
                'discloss',
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True)
            self.log_dict(
                log_dict_disc,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val')

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val')

        self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr,
            betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
@@ -409,21 +521,23 @@ class AutoencoderKL(pl.LightningModule):
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log['samples'] = self.decode(torch.randn_like(posterior.sample()))
        log['reconstructions'] = xrec
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):

    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()
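A small sketch (not code from the commit, shapes are illustrative) of the channel-last to channel-first conversion performed by get_input in both autoencoder classes above:

import torch

batch = {'image': torch.rand(4, 256, 256, 3)}  # hypothetical batch dict: B, H, W, C
x = batch['image']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
print(x.shape)  # torch.Size([4, 3, 256, 256]) -- layout expected by the Encoder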

View File

@@ -1,17 +1,26 @@
import torch
import torch.nn as nn

import modelscope.models.cv.image_to_3d.ldm.modules.attention as attention
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import \
    UNetModel
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
    timestep_embedding


class DepthAttention(nn.Module):

    def __init__(self,
                 query_dim,
                 context_dim,
                 heads,
                 dim_head,
                 output_bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = attention.default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
@@ -34,21 +43,27 @@ class DepthAttention(nn.Module):
        b, _, h, w = x.shape
        b, _, d, h, w = context.shape

        q = self.to_q(x).reshape(b, hn, hd, h, w)  # b,t,h,w
        k = self.to_k(context).reshape(b, hn, hd, d, h, w)  # b,t,d,h,w
        v = self.to_v(context).reshape(b, hn, hd, d, h, w)  # b,t,d,h,w

        sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale  # b,hn,d,h,w
        attn = sim.softmax(dim=2)

        # b,hn,hd,d,h,w * b,hn,1,d,h,w
        out = torch.sum(v * attn.unsqueeze(2), 3)  # b,hn,hd,h,w
        out = out.reshape(b, hn * hd, h, w)
        return self.to_out(out)


class DepthTransformer(nn.Module):

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 context_dim=None,
                 checkpoint=True):
        super().__init__()
        inner_dim = n_heads * d_head
        self.proj_in = nn.Sequential(
@@ -57,23 +72,33 @@ class DepthTransformer(nn.Module):
            nn.SiLU(True),
        )
        self.proj_context = nn.Sequential(
            nn.Conv3d(context_dim, context_dim, 1, 1, bias=False),  # no bias
            nn.GroupNorm(8, context_dim),
            nn.ReLU(
                True),  # only relu, because we want input is 0, output is 0
        )
        self.depth_attn = DepthAttention(
            query_dim=inner_dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=context_dim,
            output_bias=False
        )  # is a self-attention if not self.disable_self_attn
        self.proj_out = nn.Sequential(
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            attention.zero_module(
                nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
        )
        self.checkpoint = attention.checkpoint

    def forward(self, x, context=None):
        return attention.checkpoint(self._forward, (x, context),
                                    self.parameters(), self.checkpoint)  # noqa

    def _forward(self, x, context):
        x_in = x
@@ -85,38 +110,65 @@ class DepthTransformer(nn.Module):
class DepthWiseAttention(UNetModel):

    def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs):
        super().__init__(*args, **kwargs)

        # num_heads = 4
        model_channels = kwargs['model_channels']
        channel_mult = kwargs['channel_mult']
        d0, d1, d2, d3 = volume_dims

        # 4
        ch = model_channels * channel_mult[2]
        self.middle_conditions = DepthTransformer(
            ch, 4, d3 // 2, context_dim=d3)

        self.output_conditions = nn.ModuleList()
        self.output_b2c = {
            3: 0,
            4: 1,
            5: 2,
            6: 3,
            7: 4,
            8: 5,
            9: 6,
            10: 7,
            11: 8
        }

        # 8
        ch = model_channels * channel_mult[2]
        self.output_conditions.append(
            DepthTransformer(ch, 4, d2 // 2, context_dim=d2))  # 0
        self.output_conditions.append(
            DepthTransformer(ch, 4, d2 // 2, context_dim=d2))  # 1

        # 16
        self.output_conditions.append(
            DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 2
        ch = model_channels * channel_mult[1]
        self.output_conditions.append(
            DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 3
        self.output_conditions.append(
            DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 4

        # 32
        self.output_conditions.append(
            DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 5
        ch = model_channels * channel_mult[0]
        self.output_conditions.append(
            DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 6
        self.output_conditions.append(
            DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 7
        self.output_conditions.append(
            DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 8

    def forward(self,
                x,
                timesteps=None,
                context=None,
                source_dict=None,
                **kwargs):
        hs = []
        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
@@ -138,5 +190,6 @@ class DepthWiseAttention(UNetModel):
        return self.out(h)

    def get_trainable_parameters(self):
        paras = [para for para in self.middle_conditions.parameters()
                 ] + [para for para in self.output_conditions.parameters()]
        return paras
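A shape walk-through of the DepthAttention forward pass above (illustrative sizes, not from any config): queries come from a 2D feature map, keys and values from a per-depth volume, and the softmax runs over the depth axis.

import torch

b, hn, hd, d, h, w = 2, 4, 8, 16, 24, 24
q = torch.randn(b, hn, hd, h, w)     # reshaped to_q output
k = torch.randn(b, hn, hd, d, h, w)  # reshaped to_k output
v = torch.randn(b, hn, hd, d, h, w)  # reshaped to_v output

sim = torch.sum(q.unsqueeze(3) * k, 2) * hd**-0.5  # b,hn,d,h,w
attn = sim.softmax(dim=2)                          # attention over depth
out = torch.sum(v * attn.unsqueeze(2), 3)          # b,hn,hd,h,w
print(out.reshape(b, hn * hd, h, w).shape)         # torch.Size([2, 32, 24, 24])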

View File

@@ -1,10 +1,12 @@
import torch
import torch.nn as nn


class Image2DResBlockWithTV(nn.Module):

    def __init__(self, dim, tdim, vdim):
        super().__init__()
        norm = lambda c: nn.GroupNorm(8, c)  # noqa
        self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
        self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
        self.conv = nn.Sequential(
@@ -17,22 +19,28 @@ class Image2DResBlockWithTV(nn.Module):
        )

    def forward(self, x, t, v):
        return x + self.conv(x + self.time_embed(t) + self.view_embed(v))


class NoisyTargetViewEncoder(nn.Module):

    def __init__(self,
                 time_embed_dim,
                 viewpoint_dim,
                 run_dim=16,
                 output_dim=8):
        super().__init__()

        self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
        self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim,
                                               viewpoint_dim)
        self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim,
                                               viewpoint_dim)
        self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim,
                                               viewpoint_dim)
        self.final_out = nn.Sequential(
            nn.GroupNorm(8, run_dim), nn.SiLU(True),
            nn.Conv2d(run_dim, output_dim, 3, 1, 1))

    def forward(self, x, t, v):
        B, DT = t.shape
@@ -47,23 +55,33 @@ class NoisyTargetViewEncoder(nn.Module):
        x = self.final_out(x)
        return x


class SpatialUpTimeBlock(nn.Module):

    def __init__(self, x_in_dim, t_in_dim, out_dim):
        super().__init__()
        norm_act = lambda c: nn.GroupNorm(8, c)  # noqa
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1)  # 16
        self.norm = norm_act(x_in_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.ConvTranspose3d(
            x_in_dim,
            out_dim,
            kernel_size=3,
            padding=1,
            output_padding=1,
            stride=2)

    def forward(self, x, t):
        x = x + self.t_conv(t)
        return self.conv(self.silu(self.norm(x)))


class SpatialTimeBlock(nn.Module):

    def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
        super().__init__()
        norm_act = lambda c: nn.GroupNorm(8, c)  # noqa
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1)  # 16
        self.bn = norm_act(x_in_dim)
        self.silu = nn.SiLU(True)
@@ -73,61 +91,65 @@ class SpatialTimeBlock(nn.Module):
        x = x + self.t_conv(t)
        return self.conv(self.silu(self.bn(x)))


class SpatialTime3DNet(nn.Module):

    def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
        super().__init__()
        d0, d1, d2, d3 = dims
        dt = time_dim

        self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1)  # 32
        self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)

        self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
        self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
        self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)

        self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
        self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
        self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)

        self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
        self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
        self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)

        self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
        self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
        self.conv9 = SpatialUpTimeBlock(d1, dt, d0)

    def forward(self, x, t):
        B, C = t.shape
        t = t.view(B, C, 1, 1, 1)

        x = self.init_conv(x)
        conv0 = self.conv0(x, t)

        x = self.conv1(conv0, t)
        x = self.conv2_0(x, t)
        conv2 = self.conv2_1(x, t)

        x = self.conv3(conv2, t)
        x = self.conv4_0(x, t)
        conv4 = self.conv4_1(x, t)

        x = self.conv5(conv4, t)
        x = self.conv6_0(x, t)
        x = self.conv6_1(x, t)

        x = conv4 + self.conv7(x, t)
        x = conv2 + self.conv8(x, t)
        x = conv0 + self.conv9(x, t)
        return x


class FrustumTVBlock(nn.Module):

    def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
        super().__init__()
        norm_act = lambda c: nn.GroupNorm(8, c)  # noqa
        self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1)  # 16
        self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1)  # 16
        self.bn = norm_act(x_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)
@@ -136,24 +158,34 @@ class FrustumTVBlock(nn.Module):
        x = x + self.t_conv(t) + self.v_conv(v)
        return self.conv(self.silu(self.bn(x)))


class FrustumTVUpBlock(nn.Module):

    def __init__(self, x_dim, t_dim, v_dim, out_dim):
        super().__init__()
        norm_act = lambda c: nn.GroupNorm(8, c)  # noqa
        self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1)  # 16
        self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1)  # 16
        self.norm = norm_act(x_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.ConvTranspose3d(
            x_dim,
            out_dim,
            kernel_size=3,
            padding=1,
            output_padding=1,
            stride=2)

    def forward(self, x, t, v):
        x = x + self.t_conv(t) + self.v_conv(v)
        return self.conv(self.silu(self.norm(x)))


class FrustumTV3DNet(nn.Module):

    def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
        super().__init__()
        self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1)  # 32

        self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
        self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)
@@ -169,10 +201,10 @@ class FrustumTV3DNet(nn.Module):
        self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])

    def forward(self, x, t, v):
        B, DT = t.shape
        t = t.view(B, DT, 1, 1, 1)
        B, DV = v.shape
        v = v.view(B, DV, 1, 1, 1)

        b, _, d, h, w = x.shape
        x0 = self.conv0(x)
@@ -183,4 +215,4 @@ class FrustumTV3DNet(nn.Module):
        x2 = self.up0(x3, t, v) + x2
        x1 = self.up1(x2, t, v) + x1
        x0 = self.up2(x1, t, v) + x0
        return {w: x0, w // 2: x1, w // 4: x2, w // 8: x3}

View File

@@ -10,13 +10,13 @@ def project_and_normalize(ref_grid, src_proj, length):
    @param length:   int
    @return:  b, n, 2
    """
    src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:]  # b 3 n
    div_val = src_grid[:, -1:]
    div_val[div_val < 1e-4] = 1e-4
    src_grid = src_grid[:, :2] / div_val  # divide by depth (b, 2, n)
    src_grid[:, 0] = src_grid[:, 0] / ((length - 1) / 2) - 1  # scale to -1~1
    src_grid[:, 1] = src_grid[:, 1] / ((length - 1) / 2) - 1  # scale to -1~1
    src_grid = src_grid.permute(0, 2, 1)  # (b, n, 2)
    return src_grid
@@ -29,38 +29,55 @@ def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
    @return:
    """
    rfn = Ks.shape[0]
    scale_m = torch.tensor([x_ratio, y_ratio, 1.0],
                           dtype=torch.float32,
                           device=Ks.device)
    scale_m = torch.diag(scale_m)
    ref_prj = scale_m[None, :, :] @ Ks @ poses  # rfn,3,4
    pad_vals = torch.zeros([rfn, 1, 4],
                           dtype=torch.float32,
                           device=ref_prj.device)
    pad_vals[:, :, 3] = 1.0
    ref_prj = torch.cat([ref_prj, pad_vals], 1)  # rfn,4,4
    return ref_prj


def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
    B, _, D, H, W = volume_xyz.shape
    ratio = warp_size / input_size
    warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose)  # B,4,4
    warp_coords = project_and_normalize(
        volume_xyz.view(B, 3, D * H * W), warp_proj,
        warp_size).view(B, D, H, W, 2)
    return warp_coords


def create_target_volume(depth_size,
                         volume_size,
                         input_image_size,
                         pose_target,
                         K,
                         near=None,
                         far=None):
    device, dtype = pose_target.device, pose_target.dtype

    # compute a depth range on the unit sphere
    H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
    if near is not None and far is not None:
        # near, far b,1,h,w
        depth_values = torch.linspace(
            0, 1, steps=depth_size).to(near.device).to(near.dtype)  # d
        depth_values = depth_values.view(1, D, 1, 1)  # 1,d,1,1
        depth_values = depth_values * (far - near) + near  # b d h w
        depth_values = depth_values.view(B, 1, D, H * W)
    else:
        near, far = near_far_from_unit_sphere_using_camera_poses(
            pose_target)  # b 1
        depth_values = torch.linspace(
            0, 1, steps=depth_size).to(near.device).to(near.dtype)  # d
        depth_values = depth_values[None, :, None] * (
            far[:, None, :] - near[:, None, :]) + near[:, None, :]  # b d 1
        depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H * W)

    ratio = volume_size / input_image_size
@@ -68,20 +85,28 @@ def create_target_volume(depth_size, volume_size, input_image_size, pose_target,
    # H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0]
    # creat mesh grid: note reference also means target
    ref_grid = create_meshgrid(
        H, W, normalized_coordinates=False)  # (1, H, W, 2)
    ref_grid = ref_grid.to(device).to(dtype)
    ref_grid = ref_grid.permute(0, 3, 1, 2)  # (1, 2, H, W)
    ref_grid = ref_grid.reshape(1, 2, H * W)  # (1, 2, H*W)
    ref_grid = ref_grid.expand(B, -1, -1)  # (B, 2, H*W)
    ref_grid = torch.cat(
        (ref_grid,
         torch.ones(B, 1, H * W, dtype=ref_grid.dtype,
                    device=ref_grid.device)),
        dim=1)  # (B, 3, H*W)
    ref_grid = ref_grid.unsqueeze(2) * depth_values  # (B, 3, D, H*W)

    # unproject to space and transfer to world coordinates.
    Ks = K
    ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target)  # B,4,4
    ref_proj_inv = torch.inverse(ref_proj)  # B,4,4
    ref_grid = ref_proj_inv[:, :3, :3] @ ref_grid.view(
        B, 3, D * H
        * W) + ref_proj_inv[:, :3, 3:]  # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
    return ref_grid.reshape(B, 3, D, H, W), depth_values.view(B, 1, D, H, W)


def near_far_from_unit_sphere_using_camera_poses(camera_poses):
    """
@@ -90,14 +115,16 @@ def near_far_from_unit_sphere_using_camera_poses(camera_poses):
        near: b,1
        far:  b,1
    """
    R_w2c = camera_poses[..., :3, :3]  # b 3 3
    t_w2c = camera_poses[..., :3, 3:]  # b 3 1
    camera_origin = -R_w2c.permute(0, 2, 1) @ t_w2c  # b 3 1
    # R_w2c.T @ (0,0,1) = z_dir
    camera_orient = R_w2c.permute(0, 2, 1)[..., :3, 2:3]  # b 3 1
    camera_origin, camera_orient = camera_origin[...,
                                                 0], camera_orient[...,
                                                                   0]  # b 3
    a = torch.sum(camera_orient**2, dim=-1, keepdim=True)  # b 1
    b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True)  # b 1
    mid = b / a  # b 1
    near, far = mid - 1.0, mid + 1.0
    return near, far
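A quick numerical check of near_far_from_unit_sphere_using_camera_poses above (a sketch with a synthetic pose, assuming the function is in scope): a camera placed at (0, 0, -2) looking at the origin should bracket the unit sphere between depth 1 and depth 3.

import torch

R_w2c = torch.eye(3)
t_w2c = torch.tensor([[0.], [0.], [2.]])        # places the camera at (0, 0, -2)
pose = torch.cat([R_w2c, t_w2c], dim=-1)[None]  # b=1, 3x4 world-to-camera
near, far = near_far_from_unit_sphere_using_camera_poses(pose)
print(near, far)  # tensor([[1.]]) tensor([[3.]])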

View File

@@ -1,11 +1,13 @@
import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \
    checkpoint


def exists(val):
@@ -13,7 +15,7 @@ def exists(val):
def uniq(arr): def uniq(arr):
return{el: True for el in arr}.keys() return {el: True for el in arr}.keys()
def default(val, d): def default(val, d):
@@ -35,6 +37,7 @@ def init_(tensor):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2) self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -42,8 +45,11 @@ class GEGLU(nn.Module):
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate) return x * F.gelu(gate)
# feedforward # feedforward
class ConvGEGLU(nn.Module): class ConvGEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0) self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0)
@@ -54,20 +60,16 @@ class ConvGEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(nn.Linear(
nn.Linear(dim, inner_dim), dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential( self.net = nn.Sequential(project_in, nn.Dropout(dropout),
project_in, nn.Linear(inner_dim, dim_out))
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@@ -83,54 +85,54 @@ def zero_module(module):
def Normalize(in_channels): def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module): class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32): def __init__(self, dim, heads=4, dim_head=32):
super().__init__() super().__init__()
self.heads = heads self.heads = heads
hidden_dim = dim_head * heads hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) q, k, v = rearrange(
k = k.softmax(dim=-1) qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v) context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q) out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out) return self.to_out(out)
class SpatialSelfAttention(nn.Module): class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0)
kernel_size=1, self.k = torch.nn.Conv2d(
stride=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0)
padding=0) self.v = torch.nn.Conv2d(
self.k = torch.nn.Conv2d(in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0)
in_channels, self.proj_out = torch.nn.Conv2d(
kernel_size=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0)
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@@ -140,7 +142,7 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) v = self.v(h_)
# compute attention # compute attention
b,c,h,w = q.shape b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c') q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)') k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k) w_ = torch.einsum('bij,bjk->bik', q, k)
@@ -155,16 +157,22 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x + h_
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@@ -172,9 +180,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
@@ -184,12 +190,13 @@ class CrossAttention(nn.Module):
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask): if exists(mask):
mask = mask>0 mask = mask > 0
mask = rearrange(mask, 'b ... -> b (...)') mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h) mask = repeat(mask, 'b j -> (b h) () j', h=h)
@@ -202,8 +209,15 @@ class CrossAttention(nn.Module):
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
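A minimal usage sketch of the attention module in this hunk (sizes are hypothetical; the import path is assumed to be the ldm.modules.attention module being edited here):

import torch
from modelscope.models.cv.image_to_3d.ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)     # 64 query tokens per sample
ctx = torch.randn(2, 77, 768)   # e.g. 77 conditioning tokens
out = attn(x, context=ctx)      # (2, 64, 320); with context=None it reduces to self-attention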
class BasicSpatialTransformer(nn.Module): class BasicSpatialTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
def __init__(self,
dim,
n_heads,
d_head,
context_dim=None,
checkpoint=True):
super().__init__() super().__init__()
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.proj_in = nn.Sequential( self.proj_in = nn.Sequential(
@@ -212,7 +226,12 @@ class BasicSpatialTransformer(nn.Module):
nn.GroupNorm(8, inner_dim), nn.GroupNorm(8, inner_dim),
nn.ReLU(True), nn.ReLU(True),
) )
self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn self.attn = CrossAttention(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim
) # is a self-attention if not self.disable_self_attn
self.out_conv = nn.Sequential( self.out_conv = nn.Sequential(
nn.GroupNorm(8, inner_dim), nn.GroupNorm(8, inner_dim),
nn.ReLU(True), nn.ReLU(True),
@@ -221,16 +240,18 @@ class BasicSpatialTransformer(nn.Module):
self.proj_out = nn.Sequential( self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim), nn.GroupNorm(8, inner_dim),
nn.ReLU(True), nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)), zero_module(
nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
) )
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context): def _forward(self, x, context):
# input # input
b,_,h,w = x.shape b, _, h, w = x.shape
x_in = x x_in = x
x = self.proj_in(x) x = self.proj_in(x)
@@ -245,44 +266,64 @@ class BasicSpatialTransformer(nn.Module):
x = self.proj_out(x) + x_in x = self.proj_out(x) + x_in
return x return x
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False):
super().__init__() super().__init__()
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else
None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None): def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x x = self.ff(self.norm3(x)) + x
return x return x
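And a matching sketch for the transformer block above, which chains self-attention, cross-attention on the context, and the gated feed-forward, each with a residual connection (hypothetical sizes and assumed import path, as before):

import torch
from modelscope.models.cv.image_to_3d.ldm.modules.attention import BasicTransformerBlock

block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
tokens = torch.randn(2, 4096, 320)   # e.g. a 64x64 feature map flattened to tokens
cond = torch.randn(2, 77, 768)
y = block(tokens, context=cond)      # (2, 4096, 320)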
class ConvFeedForward(nn.Module): class ConvFeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1, 0), nn.Conv2d(dim, inner_dim, 1, 1, 0),
nn.GELU() nn.GELU()) if not glu else ConvGEGLU(dim, inner_dim)
) if not glu else ConvGEGLU(dim, inner_dim)
self.net = nn.Sequential( self.net = nn.Sequential(project_in, nn.Dropout(dropout),
project_in, nn.Conv2d(inner_dim, dim_out, 1, 1, 0))
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim_out, 1, 1, 0)
)
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@@ -296,31 +337,36 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action. Then apply standard transformer action.
Finally, reshape to image Finally, reshape to image
""" """
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False): disable_self_attn=False):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(
inner_dim, in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList([
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, BasicTransformerBlock(
disable_self_attn=disable_self_attn) inner_dim,
for d in range(depth)] n_heads,
) d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn) for d in range(depth)
])
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.proj_out = zero_module(
in_channels, nn.Conv2d(
kernel_size=1, inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
stride=1,
padding=0))
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention

View File
import math
from abc import abstractmethod from abc import abstractmethod
from functools import partial from functools import partial
import math
from typing import Iterable from typing import Iterable
import numpy as np import numpy as np
@@ -8,16 +8,11 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from modelscope.models.cv.image_to_3d.ldm.modules.attention import \
SpatialTransformer
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import ( from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
checkpoint, avg_pool_nd, checkpoint, conv_nd, linear, normalization,
conv_nd, timestep_embedding, zero_module)
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer
from modelscope.models.cv.image_to_3d.ldm.util import exists from modelscope.models.cv.image_to_3d.ldm.util import exists
@@ -25,11 +20,11 @@ from modelscope.models.cv.image_to_3d.ldm.util import exists
def convert_module_to_f16(x): def convert_module_to_f16(x):
pass pass
def convert_module_to_f32(x): def convert_module_to_f32(x):
pass pass
## go
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
""" """
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -43,7 +38,8 @@ class AttentionPool2d(nn.Module):
output_dim: int = None, output_dim: int = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels self.num_heads = embed_dim // num_heads_channels
@@ -98,37 +94,46 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.dims == 3: if self.dims == 3:
x = F.interpolate( x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
) mode='nearest')
else: else:
x = F.interpolate(x, scale_factor=2, mode="nearest") x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv: if self.use_conv:
x = self.conv(x) x = self.conv(x)
return x return x
class TransposedUpsample(nn.Module): class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding' 'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5): def __init__(self, channels, out_channels=None, ks=5):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2)
def forward(self,x): def forward(self, x):
return self.up(x) return self.up(x)
@@ -141,7 +146,12 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@@ -150,8 +160,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding dims,
) self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -220,7 +234,8 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
linear( linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, 2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
@@ -228,18 +243,18 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) conv_nd(
), dims, self.out_channels, self.out_channels, 3, padding=1)),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1 dims, channels, self.out_channels, 3, padding=1)
)
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(dims, channels, self.out_channels,
1)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@@ -248,10 +263,8 @@ class ResBlock(TimestepBlock):
:param emb: an [N x emb_channels] Tensor of timestep embeddings. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
return checkpoint( return checkpoint(self._forward, (x, emb), self.parameters(),
self._forward, (x, emb), self.parameters(), self.use_checkpoint self.use_checkpoint)
)
def _forward(self, x, emb): def _forward(self, x, emb):
if self.updown: if self.updown:
@@ -265,7 +278,7 @@ class ResBlock(TimestepBlock):
emb_out = self.emb_layers(emb).type(h.dtype) emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape): while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None] emb_out = emb_out[..., None]
if self.use_scale_shift_norm: # False if self.use_scale_shift_norm: # False
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1) scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift h = out_norm(h) * (1 + scale) + shift
@@ -298,7 +311,7 @@ class AttentionBlock(nn.Module):
else: else:
assert ( assert (
channels % num_head_channels == 0 channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.norm = normalization(channels) self.norm = normalization(channels)
@@ -313,8 +326,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x): def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! return checkpoint(
#return pt_checkpoint(self._forward, x) # pytorch self._forward, (x, ), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x): def _forward(self, x):
b, c, *spatial = x.shape b, c, *spatial = x.shape
@@ -341,7 +356,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops. # We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes # The first computes the weight matrix, the second computes
# the combination of the value vectors. # the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops]) model.total_ops += th.DoubleTensor([matmul_ops])
@@ -363,13 +378,14 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0 assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads) ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale 'bct,bcs->bts', q * scale,
) # More stable with f16 than dividing afterwards k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v) a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@staticmethod @staticmethod
@@ -398,12 +414,13 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1) q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum( weight = th.einsum(
"bct,bcs->bts", 'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length), (q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards ) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) a = th.einsum('bts,bcs->bct', weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@staticmethod @staticmethod
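A side note on the 1 / sqrt(sqrt(ch)) factor used by both QKV attention classes above: scaling q and k each by ch ** -0.25 before the einsum yields exactly the usual 1 / sqrt(ch) logits while keeping intermediate values in a friendlier range for fp16, which is what the 'More stable with f16' comment refers to. A tiny check with hypothetical shapes:

import math
import torch

q = torch.randn(1, 64, 16)   # (batch * heads, ch, tokens) layout used above, ch = 64
k = torch.randn(1, 64, 16)
scale = 1 / math.sqrt(math.sqrt(64))
w_pre = torch.einsum('bct,bcs->bts', q * scale, k * scale)
w_post = torch.einsum('bct,bcs->bts', q, k) / math.sqrt(64)
print(torch.allclose(w_pre, w_post, atol=1e-5))  # True: identical logits either way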
@@ -442,40 +459,43 @@ class UNetModel(nn.Module):
""" """
def __init__( def __init__(
self, self,
image_size, image_size,
in_channels, in_channels,
model_channels, model_channels,
out_channels, out_channels,
num_res_blocks, num_res_blocks,
attention_resolutions, attention_resolutions,
dropout=0, dropout=0,
channel_mult=(1, 2, 4, 8), channel_mult=(1, 2, 4, 8),
conv_resample=True, conv_resample=True,
dims=2, dims=2,
num_classes=None, num_classes=None,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
num_heads=-1, num_heads=-1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
use_new_attention_order=False, use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True, legacy=True,
disable_self_attentions=None, disable_self_attentions=None,
num_attention_blocks=None num_attention_blocks=None):
):
super().__init__() super().__init__()
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert context_dim is not None, (
'Fool!! You forgot to include the dimension '
'of your cross-attention conditioning...')
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert use_spatial_transformer, (
'Fool!! You forgot to use the spatial transformer '
'for your cross-attention conditioning...')
from omegaconf.listconfig import ListConfig from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig: if type(context_dim) == ListConfig:
context_dim = list(context_dim) context_dim = list(context_dim)
@@ -497,20 +517,28 @@ class UNetModel(nn.Module):
self.num_res_blocks = len(channel_mult) * [num_res_blocks] self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else: else:
if len(num_res_blocks) != len(channel_mult): if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or " raise ValueError(
"as a list/tuple (per-level) with the same length as channel_mult") 'provide num_res_blocks either as an int (globally constant) or '
'as a list/tuple (per-level) with the same length as channel_mult'
)
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
#self.num_res_blocks = num_res_blocks # self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None: if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult) assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None: if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks) assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) assert all(
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " map(
f"This option has LESS priority than attention_resolutions {attention_resolutions}, " lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " ],
f"attention will still not be set.") # todo: convert to warning range(len(num_attention_blocks))))
print(
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
f'attention will still not be set.'
) # todo: convert to warning
self.attention_resolutions = attention_resolutions self.attention_resolutions = attention_resolutions
self.dropout = dropout self.dropout = dropout
@@ -534,13 +562,10 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList([
[ TimestepEmbedSequential(
TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1))
conv_nd(dims, in_channels, model_channels, 3, padding=1) ]) # 0
)
]
) # 0
self._feature_size = model_channels self._feature_size = model_channels
input_block_chans = [model_channels] input_block_chans = [model_channels]
ch = model_channels ch = model_channels
@@ -559,21 +584,22 @@ class UNetModel(nn.Module):
) )
] ]
ch = mult * model_channels ch = mult * model_channels
if ds in attention_resolutions: # always True if ds in attention_resolutions: # always True
if num_head_channels == -1: if num_head_channels == -1:
dim_head = ch // num_heads dim_head = ch // num_heads
else: else:
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions): if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level] disabled_sa = disable_self_attentions[level]
else: else:
disabled_sa = False disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: if not exists(num_attention_blocks
) or nr < num_attention_blocks[level]:
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@@ -581,11 +607,14 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( ) if not use_spatial_transformer else
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, SpatialTransformer(
disable_self_attn=disabled_sa ch,
) num_heads,
) dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa))
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
input_block_chans.append(ch) input_block_chans.append(ch)
@@ -602,12 +631,8 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
down=True, down=True,
) ) if resblock_updown else Downsample(
if resblock_updown ch, conv_resample, dims=dims, out_channels=out_ch))
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
) )
ch = out_ch ch = out_ch
input_block_chans.append(ch) input_block_chans.append(ch)
@@ -620,7 +645,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential( self.middle_block = TimestepEmbedSequential(
ResBlock( ResBlock(
@@ -637,9 +662,13 @@ class UNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn ) if not use_spatial_transformer else
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim SpatialTransformer( # always uses a self-attn
), ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim),
ResBlock( ResBlock(
ch, ch,
time_embed_dim, time_embed_dim,
@@ -674,14 +703,15 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels num_heads = ch // num_head_channels
dim_head = num_head_channels dim_head = num_head_channels
if legacy: if legacy:
#num_heads = 1 # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions): if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level] disabled_sa = disable_self_attentions[level]
else: else:
disabled_sa = False disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]: if not exists(num_attention_blocks
) or i < num_attention_blocks[level]:
layers.append( layers.append(
AttentionBlock( AttentionBlock(
ch, ch,
@@ -689,11 +719,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample, num_heads=num_heads_upsample,
num_head_channels=dim_head, num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( ) if not use_spatial_transformer else
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, SpatialTransformer(
disable_self_attn=disabled_sa ch,
) num_heads,
) dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa))
if level and i == self.num_res_blocks[level]: if level and i == self.num_res_blocks[level]:
out_ch = ch out_ch = ch
layers.append( layers.append(
@@ -706,10 +739,8 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
up=True, up=True,
) ) if resblock_updown else Upsample(
if resblock_updown ch, conv_resample, dims=dims, out_channels=out_ch))
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
@@ -717,14 +748,15 @@ class UNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
normalization(ch), normalization(ch),
conv_nd(dims, model_channels, n_embed, 1), conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@@ -742,7 +774,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.
@@ -753,18 +785,19 @@ class UNetModel(nn.Module):
""" """
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), 'must specify y if and only if the model is class-conditional'
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N t_emb = timestep_embedding(
emb = self.time_embed(t_emb) # timesteps, self.model_channels, repeat_only=False) # N
emb = self.time_embed(t_emb) #
if self.num_classes is not None: if self.num_classes is not None:
assert y.shape == (x.shape[0],) assert y.shape == (x.shape[0], )
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for module in self.input_blocks:
h = module(h, emb, context) # conv h = module(h, emb, context) # conv
hs.append(h) hs.append(h)
h = self.middle_block(h, emb, context) h = self.middle_block(h, emb, context)
for module in self.output_blocks: for module in self.output_blocks:
@@ -783,30 +816,28 @@ class EncoderUNetModel(nn.Module):
For usage, see UNet. For usage, see UNet.
""" """
def __init__( def __init__(self,
self, image_size,
image_size, in_channels,
in_channels, model_channels,
model_channels, out_channels,
out_channels, num_res_blocks,
num_res_blocks, attention_resolutions,
attention_resolutions, dropout=0,
dropout=0, channel_mult=(1, 2, 4, 8),
channel_mult=(1, 2, 4, 8), conv_resample=True,
conv_resample=True, dims=2,
dims=2, use_checkpoint=False,
use_checkpoint=False, use_fp16=False,
use_fp16=False, num_heads=1,
num_heads=1, num_head_channels=-1,
num_head_channels=-1, num_heads_upsample=-1,
num_heads_upsample=-1, use_scale_shift_norm=False,
use_scale_shift_norm=False, resblock_updown=False,
resblock_updown=False, use_new_attention_order=False,
use_new_attention_order=False, pool='adaptive',
pool="adaptive", *args,
*args, **kwargs):
**kwargs
):
super().__init__() super().__init__()
if num_heads_upsample == -1: if num_heads_upsample == -1:
@@ -833,13 +864,10 @@ class EncoderUNetModel(nn.Module):
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim),
) )
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList([
[ TimestepEmbedSequential(
TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1))
conv_nd(dims, in_channels, model_channels, 3, padding=1) ])
)
]
)
self._feature_size = model_channels self._feature_size = model_channels
input_block_chans = [model_channels] input_block_chans = [model_channels]
ch = model_channels ch = model_channels
@@ -866,8 +894,7 @@ class EncoderUNetModel(nn.Module):
num_heads=num_heads, num_heads=num_heads,
num_head_channels=num_head_channels, num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order, use_new_attention_order=use_new_attention_order,
) ))
)
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
input_block_chans.append(ch) input_block_chans.append(ch)
@@ -884,12 +911,8 @@ class EncoderUNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
down=True, down=True,
) ) if resblock_updown else Downsample(
if resblock_updown ch, conv_resample, dims=dims, out_channels=out_ch))
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
) )
ch = out_ch ch = out_ch
input_block_chans.append(ch) input_block_chans.append(ch)
@@ -923,7 +946,7 @@ class EncoderUNetModel(nn.Module):
) )
self._feature_size += ch self._feature_size += ch
self.pool = pool self.pool = pool
if pool == "adaptive": if pool == 'adaptive':
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
@@ -931,22 +954,21 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)), zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(), nn.Flatten(),
) )
elif pool == "attention": elif pool == 'attention':
assert num_head_channels != -1 assert num_head_channels != -1
self.out = nn.Sequential( self.out = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
AttentionPool2d( AttentionPool2d((image_size // ds), ch, num_head_channels,
(image_size // ds), ch, num_head_channels, out_channels out_channels),
),
) )
elif pool == "spatial": elif pool == 'spatial':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
nn.ReLU(), nn.ReLU(),
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
elif pool == "spatial_v2": elif pool == 'spatial_v2':
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048), nn.Linear(self._feature_size, 2048),
normalization(2048), normalization(2048),
@@ -954,7 +976,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels), nn.Linear(2048, self.out_channels),
) )
else: else:
raise NotImplementedError(f"Unexpected {pool} pooling") raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self): def convert_to_fp16(self):
""" """
@@ -977,20 +999,20 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps. :param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs. :return: an [N x K] Tensor of outputs.
""" """
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
results = [] results = []
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for module in self.input_blocks:
h = module(h, emb) h = module(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb) h = self.middle_block(h, emb)
if self.pool.startswith("spatial"): if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3))) results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1) h = th.cat(results, axis=-1)
return self.out(h) return self.out(h)
else: else:
h = h.type(x.dtype) h = h.type(x.dtype)
return self.out(h) return self.out(h)
View File
@@ -7,50 +7,65 @@
# #
# thanks! # thanks!
import os
import math import math
import os
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from einops import repeat from einops import repeat
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(schedule,
if schedule == "linear": n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == 'linear':
betas = ( betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 torch.linspace(
) linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == "cosine": elif schedule == 'cosine':
timesteps = ( timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
) + cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2) alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0] alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1] betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999) betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear": elif schedule == 'sqrt_linear':
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) betas = torch.linspace(
elif schedule == "sqrt": linear_start, linear_end, n_timestep, dtype=torch.float64)
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 elif schedule == 'sqrt':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy() return betas.numpy()
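For orientation, a quick check of the 'linear' schedule defined above (settings are hypothetical but in the usual LDM range; assumes make_beta_schedule is importable from the diffusionmodules util module edited in this file):

import numpy as np
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule

betas = make_beta_schedule('linear', n_timestep=1000, linear_start=1e-4, linear_end=2e-2)
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas[0], betas[-1])    # endpoints land on linear_start and linear_end (up to float error)
print(alphas_cumprod[-1])     # close to zero: nearly all signal is noised away by the final step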
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform': if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad': elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
else: else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps # assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
@@ -60,17 +75,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule # select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps] alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) alphas_prev = np.asarray([alphacums[0]]
+ alphacums[ddim_timesteps[:-1]].tolist())
# according to the formula provided in https://arxiv.org/abs/2010.02502 # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) # rewrite because of E125
tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
sigmas = (eta * np.sqrt(tmp))
if verbose: if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') print(
print(f'For the chosen value of eta, which is {eta}, ' f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') )
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev return sigmas, alphas, alphas_prev
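The sigma schedule above (in both the original one-liner and the E125-friendly rewrite) is the DDIM variance formula from the paper referenced in the comment, in LaTeX:

\sigma_t = \eta \, \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}} \, \sqrt{1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}

so eta = 0 gives the deterministic DDIM sampler and eta = 1 recovers DDPM-style ancestral noise.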
@@ -96,7 +121,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
def extract_into_tensor(a, t, x_shape): def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape b, *_ = t.shape
out = a.gather(-1, t) out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1))) return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag): def checkpoint(func, inputs, params, flag):
@@ -117,6 +142,7 @@ def checkpoint(func, inputs, params, flag):
class CheckpointFunction(torch.autograd.Function): class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, run_function, length, *args): def forward(ctx, run_function, length, *args):
ctx.run_function = run_function ctx.run_function = run_function
@@ -129,7 +155,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *output_grads): def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad(): with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the # Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d # Tensor storage in place, which is not allowed for detach()'d
@@ -160,12 +188,14 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half -math.log(max_period)
).to(device=timesteps.device) * torch.arange(start=0, end=half, dtype=torch.float32)
/ half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else: else:
embedding = repeat(timesteps, 'b -> b d', d=dim) embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding return embedding
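A small sanity check of the sinusoidal embedding above (hypothetical values; assumes timestep_embedding is importable from the same util module, as openaimodel.py does earlier in this commit):

import torch
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding

t = torch.tensor([0, 10, 100, 1000])
emb = timestep_embedding(t, dim=8)
print(emb.shape)  # torch.Size([4, 8]): first half cosines, second half sines over geometric frequencies
print(emb[0])     # the t = 0 row is [1, 1, 1, 1, 0, 0, 0, 0]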
@@ -207,14 +237,17 @@ def normalization(channels):
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module): class SiLU(nn.Module):
def forward(self, x): def forward(self, x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm): class GroupNorm32(nn.GroupNorm):
def forward(self, x): def forward(self, x):
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D convolution module. Create a 1D, 2D, or 3D convolution module.
@@ -225,7 +258,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs) return nn.Conv2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.Conv3d(*args, **kwargs) return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs): def linear(*args, **kwargs):
@@ -245,7 +278,7 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs) return nn.AvgPool2d(*args, **kwargs)
elif dims == 3: elif dims == 3:
return nn.AvgPool3d(*args, **kwargs) return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module): class HybridConditioner(nn.Module):
@@ -253,7 +286,8 @@ class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config): def __init__(self, c_concat_config, c_crossattn_config):
super().__init__() super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config) self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn): def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat) c_concat = self.concat_conditioner(c_concat)
@@ -262,6 +296,13 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False): def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device) def repeat_noise():
return repeat_noise() if repeat else noise() return torch.randn((1, *shape[1:]),
device=device).repeat(shape[0],
*((1, ) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
View File
@@ -1,8 +1,9 @@
import torch
import numpy as np import numpy as np
import torch
class AbstractDistribution: class AbstractDistribution:
def sample(self): def sample(self):
raise NotImplementedError() raise NotImplementedError()
@@ -11,6 +12,7 @@ class AbstractDistribution:
class DiracDistribution(AbstractDistribution): class DiracDistribution(AbstractDistribution):
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
@@ -22,6 +24,7 @@ class DiracDistribution(AbstractDistribution):
class DiagonalGaussianDistribution(object): class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False): def __init__(self, parameters, deterministic=False):
self.parameters = parameters self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
@@ -30,10 +33,12 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar) self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar) self.var = torch.exp(self.logvar)
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self): def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) x = self.mean + self.std * torch.randn(
self.mean.shape).to(device=self.parameters.device)
return x return x
def kl(self, other=None): def kl(self, other=None):
@@ -41,21 +46,22 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.]) return torch.Tensor([0.])
else: else:
if other is None: if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2) return 0.5 * torch.sum(
+ self.var - 1.0 - self.logvar, torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3]) dim=[1, 2, 3])
else: else:
return 0.5 * torch.sum( return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar, + self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3]) dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]): def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic: if self.deterministic:
return torch.Tensor([0.]) return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi) logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum( return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, logtwopi + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims) dim=dims)
def mode(self): def mode(self):
@@ -64,7 +70,8 @@ class DiagonalGaussianDistribution(object):
def normal_kl(mean1, logvar1, mean2, logvar2): def normal_kl(mean1, logvar1, mean2, logvar2):
""" """
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 source: https://github.com/openai/guided-diffusion/blob/
27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians. Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases. scalars, among other use cases.
@@ -74,7 +81,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor = obj tensor = obj
break break
assert tensor is not None, "at least one argument must be a Tensor" assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to # Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp(). # Tensors, but it does not work for torch.exp().
@@ -83,10 +90,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
for x in (logvar1, logvar2) for x in (logvar1, logvar2)
] ]
return 0.5 * ( # rewrite because of W504
-1.0 tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
+ logvar2 return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + tmp
- logvar1 ) # noqa
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
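As a sanity check on the closed-form expression above, a small sketch (not part of the commit; all names local to the example) comparing it against torch.distributions for diagonal Gaussians:

import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.randn(4), torch.randn(4)
mean2, logvar2 = torch.randn(4), torch.randn(4)

# the rewritten expression from normal_kl
tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
kl_manual = 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + tmp)

# reference value from torch.distributions (std = exp(0.5 * logvar))
kl_ref = kl_divergence(
    Normal(mean1, torch.exp(0.5 * logvar1)),
    Normal(mean2, torch.exp(0.5 * logvar2)))

assert torch.allclose(kl_manual, kl_ref, atol=1e-5)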

View File

@@ -2,15 +2,17 @@ import hashlib
import os import os
import urllib import urllib
import warnings import warnings
from typing import Any, Union, List from typing import Any, List, Union
from pkg_resources import packaging
import torch import torch
from PIL import Image from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from pkg_resources import packaging
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
ToTensor)
from tqdm import tqdm from tqdm import tqdm
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import \
build_model
try: try:
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
@@ -18,23 +20,40 @@ try:
except ImportError: except ImportError:
BICUBIC = Image.BICUBIC BICUBIC = Image.BICUBIC
if packaging.version.parse(
torch.__version__) < packaging.version.parse('1.7.1'):
warnings.warn('PyTorch version 1.7.1 or higher is recommended')
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): __all__ = ['available_models', 'load']
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load"]
_MODELS = { _MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 'RN50':
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 'https://openaipublic.azureedge.net/clip/models/'
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 'afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 'RN101':
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 'https://openaipublic.azureedge.net/clip/models/'
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", '8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 'RN50x4':
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 'https://openaipublic.azureedge.net/clip/models/'
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", '7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
'RN50x16':
'https://openaipublic.azureedge.net/clip/models/'
'52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
'RN50x64':
'https://openaipublic.azureedge.net/clip/models/'
'be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
'ViT-B/32':
'https://openaipublic.azureedge.net/clip/models/'
'40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
'ViT-B/16':
'https://openaipublic.azureedge.net/clip/models/'
'5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
'ViT-L/14':
'https://openaipublic.azureedge.net/clip/models/'
'b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
'ViT-L/14@336px':
'https://openaipublic.azureedge.net/clip/models/'
'3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
} }
@@ -42,20 +61,30 @@ def _download(url: str, root: str):
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
filename = os.path.basename(url) filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2] expected_sha256 = url.split('/')[-2]
download_target = os.path.join(root, filename) download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target): if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file") raise RuntimeError(
f'{download_target} exists and is not a regular file')
if os.path.isfile(download_target): if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: if hashlib.sha256(open(download_target,
'rb').read()).hexdigest() == expected_sha256:
return download_target return download_target
else: else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") warnings.warn(
f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with urllib.request.urlopen(url) as source, open(download_target,
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 'wb') as output:
with tqdm(
total=int(source.info().get('Content-Length')),
ncols=80,
unit='iB',
unit_scale=True,
unit_divisor=1024) as loop:
while True: while True:
buffer = source.read(8192) buffer = source.read(8192)
if not buffer: if not buffer:
@@ -64,14 +93,17 @@ def _download(url: str, root: str):
output.write(buffer) output.write(buffer)
loop.update(len(buffer)) loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: if hashlib.sha256(open(download_target,
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 'rb').read()).hexdigest() != expected_sha256:
raise RuntimeError(
'Model has been downloaded but the SHA256 checksum does not match'
)
return download_target return download_target
def _convert_image_to_rgb(image): def _convert_image_to_rgb(image):
return image.convert("RGB") return image.convert('RGB')
def _transform(n_px): def _transform(n_px):
@@ -80,7 +112,8 @@ def _transform(n_px):
CenterCrop(n_px), CenterCrop(n_px),
_convert_image_to_rgb, _convert_image_to_rgb,
ToTensor(), ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
]) ])
@@ -89,7 +122,11 @@ def available_models() -> List[str]:
return list(_MODELS.keys()) return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): def load(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit: bool = False,
download_root: str = None):
"""Load a CLIP model """Load a CLIP model
Parameters Parameters
@@ -115,37 +152,47 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
""" """
if name in _MODELS: if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) model_path = _download(
_MODELS[name], download_root
or os.path.expanduser('~/.cache/clip'))
elif os.path.isfile(name): elif os.path.isfile(name):
model_path = name model_path = name
else: else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}") raise RuntimeError(
f'Model {name} not found; available models = {available_models()}')
with open(model_path, 'rb') as opened_file: with open(model_path, 'rb') as opened_file:
try: try:
# loading JIT archive # loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() model = torch.jit.load(
opened_file, map_location=device if jit else 'cpu').eval()
state_dict = None state_dict = None
except RuntimeError: except RuntimeError:
# loading saved state dict # loading saved state dict
if jit: if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") warnings.warn(
f'File {model_path} is not a JIT archive. Loading as a state dict instead'
)
jit = False jit = False
state_dict = torch.load(opened_file, map_location="cpu") state_dict = torch.load(opened_file, map_location='cpu')
if not jit: if not jit:
model = build_model(state_dict or model.state_dict()).to(device) model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu": if str(device) == 'cpu':
model.float() model.float()
return model, _transform(model.visual.input_resolution) return model, _transform(model.visual.input_resolution)
# patch the device names # patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_holder = torch.jit.trace(
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]
def _node_get(node: torch._C.Node, key: str): def _node_get(node: torch._C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type. """Gets attributes of a node which is polymorphic over return type.
From https://github.com/pytorch/pytorch/pull/82628 From https://github.com/pytorch/pytorch/pull/82628
""" """
sel = node.kindOf(key) sel = node.kindOf(key)
@@ -153,16 +200,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
def patch_device(module): def patch_device(module):
try: try:
graphs = [module.graph] if hasattr(module, "graph") else [] graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError: except RuntimeError:
graphs = [] graphs = []
if hasattr(module, "forward1"): if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph) graphs.append(module.forward1.graph)
for graph in graphs: for graph in graphs:
for node in graph.findAllNodes("prim::Constant"): for node in graph.findAllNodes('prim::Constant'):
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): if 'value' in node.attributeNames() and str(
_node_get(node, 'value')).startswith('cuda'):
node.copyAttributes(device_node) node.copyAttributes(device_node)
model.apply(patch_device) model.apply(patch_device)
@@ -170,25 +218,28 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
patch_device(model.encode_text) patch_device(model.encode_text)
# patch dtype to float32 on CPU # patch dtype to float32 on CPU
if str(device) == "cpu": if str(device) == 'cpu':
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_holder = torch.jit.trace(
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node() float_node = float_input.node()
def patch_float(module): def patch_float(module):
try: try:
graphs = [module.graph] if hasattr(module, "graph") else [] graphs = [module.graph] if hasattr(module, 'graph') else []
except RuntimeError: except RuntimeError:
graphs = [] graphs = []
if hasattr(module, "forward1"): if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph) graphs.append(module.forward1.graph)
for graph in graphs: for graph in graphs:
for node in graph.findAllNodes("aten::to"): for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs()) inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to() for i in [
if _node_get(inputs[i].node(), "value") == 5: 1, 2
]: # dtype can be the second or third argument to aten::to()
if _node_get(inputs[i].node(), 'value') == 5:
inputs[i].node().copyAttributes(float_node) inputs[i].node().copyAttributes(float_node)
model.apply(patch_float) model.apply(patch_float)
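For orientation, a minimal usage sketch of this loader; it assumes load is reachable on the clip package as imported later in this commit, and 'photo.jpg' is a placeholder path:

import torch
from PIL import Image
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip  # assumes load/available_models are re-exported here

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# first call downloads the checkpoint to ~/.cache/clip and verifies its SHA256
model, preprocess = clip.load('ViT-L/14', device=device, jit=False)
image = preprocess(Image.open('photo.jpg')).unsqueeze(0).to(device)  # placeholder image path
image_features = model.encode_image(image)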

View File

@@ -33,11 +33,16 @@ class Bottleneck(nn.Module):
if stride > 1 or inplanes != planes * Bottleneck.expansion: if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([ self.downsample = nn.Sequential(
("-1", nn.AvgPool2d(stride)), OrderedDict([('-1', nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ('0',
("1", nn.BatchNorm2d(planes * self.expansion)) nn.Conv2d(
])) inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
identity = x identity = x
@@ -56,9 +61,15 @@ class Bottleneck(nn.Module):
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -70,14 +81,17 @@ class AttentionPool2d(nn.Module):
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward( x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x, query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1], embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads, num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight, q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight, k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight, v_proj_weight=self.v_proj.weight,
in_proj_weight=None, in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None, bias_k=None,
bias_v=None, bias_v=None,
add_zero_attn=False, add_zero_attn=False,
@@ -86,8 +100,7 @@ class AttentionPool2d(nn.Module):
out_proj_bias=self.c_proj.bias, out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True, use_separate_proj_weight=True,
training=self.training, training=self.training,
need_weights=False need_weights=False)
)
return x.squeeze(0) return x.squeeze(0)
@@ -99,19 +112,27 @@ class ModifiedResNet(nn.Module):
- The final pooling layer is a QKV attention instead of an average pool - The final pooling layer is a QKV attention instead of an average pool
""" """
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__() super().__init__()
self.output_dim = output_dim self.output_dim = output_dim
self.input_resolution = input_resolution self.input_resolution = input_resolution
# the 3-layer stem # the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2) self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True) self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2) self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True) self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width) self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True) self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2) self.avgpool = nn.AvgPool2d(2)
@@ -124,7 +145,8 @@ class ModifiedResNet(nn.Module):
self.layer4 = self._make_layer(width * 8, layers[3], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)
def _make_layer(self, planes, blocks, stride=1): def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)] layers = [Bottleneck(self._inplanes, planes, stride)]
@@ -136,6 +158,7 @@ class ModifiedResNet(nn.Module):
return nn.Sequential(*layers) return nn.Sequential(*layers)
def forward(self, x): def forward(self, x):
def stem(x): def stem(x):
x = self.relu1(self.bn1(self.conv1(x))) x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x))) x = self.relu2(self.bn2(self.conv2(x)))
@@ -164,27 +187,34 @@ class LayerNorm(nn.LayerNorm):
class QuickGELU(nn.Module): class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x) return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__() super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head) self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model) self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([ self.mlp = nn.Sequential(
("c_fc", nn.Linear(d_model, d_model * 4)), OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()), ('gelu', QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)) ('c_proj', nn.Linear(d_model * 4, d_model))]))
]))
self.ln_2 = LayerNorm(d_model) self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask self.attn_mask = attn_mask
def attention(self, x: torch.Tensor): def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None self.attn_mask = self.attn_mask.to(
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x)) x = x + self.attention(self.ln_1(x))
@@ -193,26 +223,42 @@ class ResidualAttentionBlock(nn.Module):
class Transformer(nn.Module): class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None):
super().__init__() super().__init__()
self.width = width self.width = width
self.layers = layers self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return self.resblocks(x) return self.resblocks(x)
class VisionTransformer(nn.Module): class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__() super().__init__()
self.input_resolution = input_resolution self.input_resolution = input_resolution
self.output_dim = output_dim self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width ** -0.5 scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width) self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads) self.transformer = Transformer(width, layers, heads)
@@ -222,9 +268,15 @@ class VisionTransformer(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid] x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
# rewrite because of E126
tmp = self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) # noqa
x = torch.cat([tmp, x], dim=1)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype) x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x) x = self.ln_pre(x)
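To make the patchify reshaping explicit, a shape walk-through with ViT-L/14-like hyperparameters (input 224, patch_size 14, width 1024; numbers are illustrative, not part of the commit):

import torch
conv1 = torch.nn.Conv2d(3, 1024, kernel_size=14, stride=14, bias=False)
x = conv1(torch.randn(1, 3, 224, 224))        # [1, 1024, 16, 16]  (224 / 14 = 16)
x = x.reshape(1, 1024, -1).permute(0, 2, 1)   # [1, 256, 1024]     (grid ** 2 tokens)
cls_tok = torch.zeros(1, 1, 1024)             # stand-in for the broadcast class_embedding
x = torch.cat([cls_tok, x], dim=1)            # [1, 257, 1024]     (grid ** 2 + 1 tokens)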
@@ -241,20 +293,21 @@ class VisionTransformer(nn.Module):
class CLIP(nn.Module): class CLIP(nn.Module):
def __init__(self,
embed_dim: int, def __init__(
# vision self,
image_resolution: int, embed_dim: int,
vision_layers: Union[Tuple[int, int, int, int], int], # vision
vision_width: int, image_resolution: int,
vision_patch_size: int, vision_layers: Union[Tuple[int, int, int, int], int],
# text vision_width: int,
context_length: int, vision_patch_size: int,
vocab_size: int, # text
transformer_width: int, context_length: int,
transformer_heads: int, vocab_size: int,
transformer_layers: int transformer_width: int,
): transformer_heads: int,
transformer_layers: int):
super().__init__() super().__init__()
self.context_length = context_length self.context_length = context_length
@@ -266,8 +319,7 @@ class CLIP(nn.Module):
output_dim=embed_dim, output_dim=embed_dim,
heads=vision_heads, heads=vision_heads,
input_resolution=image_resolution, input_resolution=image_resolution,
width=vision_width width=vision_width)
)
else: else:
vision_heads = vision_width // 64 vision_heads = vision_width // 64
self.visual = VisionTransformer( self.visual = VisionTransformer(
@@ -276,22 +328,22 @@ class CLIP(nn.Module):
width=vision_width, width=vision_width,
layers=vision_layers, layers=vision_layers,
heads=vision_heads, heads=vision_heads,
output_dim=embed_dim output_dim=embed_dim)
)
self.transformer = Transformer( self.transformer = Transformer(
width=transformer_width, width=transformer_width,
layers=transformer_layers, layers=transformer_layers,
heads=transformer_heads, heads=transformer_heads,
attn_mask=self.build_attention_mask() attn_mask=self.build_attention_mask())
)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width) self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters() self.initialize_parameters()
@@ -302,20 +354,24 @@ class CLIP(nn.Module):
if isinstance(self.visual, ModifiedResNet): if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None: if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5 std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters(): for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"): if name.endswith('bn3.weight'):
nn.init.zeros_(param) nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) proj_std = (self.transformer.width**-0.5) * (
attn_std = self.transformer.width ** -0.5 (2 * self.transformer.layers)**-0.5)
fc_std = (2 * self.transformer.width) ** -0.5 attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks: for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
@@ -323,13 +379,14 @@ class CLIP(nn.Module):
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None: if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self): def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens # lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf # pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length) mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf")) mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal mask.triu_(1) # zero out the lower diagonal
return mask return mask
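A tiny illustration of the additive causal mask this produces, for a context length of 4 (illustrative only):

import torch
mask = torch.empty(4, 4).fill_(float('-inf')).triu_(1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# each position attends only to itself and earlier positions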
@@ -341,7 +398,8 @@ class CLIP(nn.Module):
return self.visual(image.type(self.dtype)) return self.visual(image.type(self.dtype))
def encode_text(self, text): def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype) x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND x = x.permute(1, 0, 2) # NLD -> LND
@@ -351,7 +409,8 @@ class CLIP(nn.Module):
# x.shape = [batch_size, n_ctx, transformer.width] # x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence) # take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return x return x
@@ -360,7 +419,8 @@ class CLIP(nn.Module):
text_features = self.encode_text(text) text_features = self.encode_text(text)
# normalized features # normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True) image_features = image_features / image_features.norm(
dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True) text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits # cosine similarity as logits
@@ -375,21 +435,24 @@ class CLIP(nn.Module):
def convert_weights(model: nn.Module): def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16""" """Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l): def _convert_weights_to_fp16(_l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): if isinstance(_l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half() _l.weight.data = _l.weight.data.half()
if l.bias is not None: if _l.bias is not None:
l.bias.data = l.bias.data.half() _l.bias.data = _l.bias.data.half()
if isinstance(l, nn.MultiheadAttention): if isinstance(_l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: for attr in [
tensor = getattr(l, attr) *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(_l, attr)
if tensor is not None: if tensor is not None:
tensor.data = tensor.data.half() tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]: for name in ['text_projection', 'proj']:
if hasattr(l, name): if hasattr(_l, name):
attr = getattr(l, name) attr = getattr(_l, name)
if attr is not None: if attr is not None:
attr.data = attr.data.half() attr.data = attr.data.half()
@@ -397,37 +460,51 @@ def convert_weights(model: nn.Module):
def build_model(state_dict: dict): def build_model(state_dict: dict):
vit = "visual.proj" in state_dict vit = 'visual.proj' in state_dict
if vit: if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0] vision_width = state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_layers = len([
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] k for k in state_dict.keys()
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
grid_size = round(
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
image_resolution = vision_patch_size * grid_size image_resolution = vision_patch_size * grid_size
else: else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] counts: list = [
len(
set(
k.split('.')[2] for k in state_dict
if k.startswith(f'visual.layer{b}')))
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts) vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) output_width = round(
(state_dict['visual.attnpool.positional_embedding'].shape[0]
- 1)**0.5)
vision_patch_size = None vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] assert output_width**2 + 1 == state_dict[
'visual.attnpool.positional_embedding'].shape[0]
image_resolution = output_width * 32 image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1] embed_dim = state_dict['text_projection'].shape[1]
context_length = state_dict["positional_embedding"].shape[0] context_length = state_dict['positional_embedding'].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0] vocab_size = state_dict['token_embedding.weight'].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0] transformer_width = state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64 transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) transformer_layers = len(
set(
k.split('.')[2] for k in state_dict
if k.startswith('transformer.resblocks')))
model = CLIP( model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
embed_dim, vision_patch_size, context_length, vocab_size,
image_resolution, vision_layers, vision_width, vision_patch_size, transformer_width, transformer_heads, transformer_layers)
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]: for key in ['input_resolution', 'context_length', 'vocab_size']:
if key in state_dict: if key in state_dict:
del state_dict[key] del state_dict[key]

View File

@@ -9,7 +9,9 @@ import regex as re
@lru_cache() @lru_cache()
def default_bpe(): def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_simple_vocab_16e6.txt.gz')
@lru_cache() @lru_cache()
@@ -23,13 +25,17 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on. And avoids mapping to whitespace/control characters the bpe code barfs on.
""" """
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:] cs = bs[:]
n = 0 n = 0
for b in range(2**8): for b in range(2**8):
if b not in bs: if b not in bs:
bs.append(b) bs.append(b)
cs.append(2**8+n) cs.append(2**8 + n)
n += 1 n += 1
cs = [chr(n) for n in cs] cs = [chr(n) for n in cs]
return dict(zip(bs, cs)) return dict(zip(bs, cs))
@@ -60,34 +66,41 @@ def whitespace_clean(text):
class SimpleTokenizer(object): class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()): def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode() self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152-256-2+1] merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values()) vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab] vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges: for merge in merges:
vocab.append(''.join(merge)) vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>']) vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab)))) self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} self.cache = {
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) '<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token): def bpe(self, token):
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',) word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token+'</w>' return token + '</w>'
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
@@ -98,12 +111,13 @@ class SimpleTokenizer(object):
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) new_word.extend(word[i:j])
i = j i = j
except: except Exception:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[
new_word.append(first+second) i + 1] == second:
new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
@@ -122,11 +136,14 @@ class SimpleTokenizer(object):
bpe_tokens = [] bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower() text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) token = ''.join(self.byte_encoder[b]
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens return bpe_tokens
def decode(self, tokens): def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens]) text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text return text
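A round-trip sketch of this BPE tokenizer; the module path is an assumption (not shown in this diff) and the bundled bpe_simple_vocab_16e6.txt.gz must be present next to the module:

from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.simple_tokenizer import \
    SimpleTokenizer  # hypothetical module path

tok = SimpleTokenizer()                 # loads the default BPE vocab from the package dir
ids = tok.encode('a photo of a cat')    # list of BPE token ids
print(tok.decode(ids))                  # 'a photo of a cat ' (</w> markers render as spaces)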

View File

@@ -1,28 +1,45 @@
import random
from functools import partial
import kornia
import kornia.augmentation as K
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import torch.nn.functional as F
from functools import partial from torchvision import transforms
import kornia from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel,
T5EncoderModel, T5Tokenizer)
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
from modelscope.models.cv.image_to_3d.ldm.util import default extract_into_tensor, make_beta_schedule, noise_like)
# import clip # import clip
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip
# TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import (
Encoder, TransformerWrapper)
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
from modelscope.models.cv.image_to_3d.ldm.util import (default,
instantiate_from_config)
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def encode(self, *args, **kwargs): def encode(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
class IdentityEncoder(AbstractEncoder): class IdentityEncoder(AbstractEncoder):
def encode(self, x): def encode(self, x):
return x return x
class FaceClipEncoder(AbstractEncoder): class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None): def __init__(self, augment=True, retreival_key=None):
super().__init__() super().__init__()
self.encoder = FrozenCLIPImageEmbedder() self.encoder = FrozenCLIPImageEmbedder()
@@ -35,16 +52,16 @@ class FaceClipEncoder(AbstractEncoder):
x_offset = 125 x_offset = 125
if self.retreival_key: if self.retreival_key:
# Assumes retrieved image are packed into the second half of channels # Assumes retrieved image are packed into the second half of channels
face = img[:,3:,190:440,x_offset:(512-x_offset)] face = img[:, 3:, 190:440, x_offset:(512 - x_offset)]
other = img[:,:3,...].clone() other = img[:, :3, ...].clone()
else: else:
face = img[:,:,190:440,x_offset:(512-x_offset)] face = img[:, :, 190:440, x_offset:(512 - x_offset)]
other = img.clone() other = img.clone()
if self.augment: if self.augment:
face = K.RandomHorizontalFlip()(face) face = K.RandomHorizontalFlip()(face)
other[:,:,190:440,x_offset:(512-x_offset)] *= 0 other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0
encodings = [ encodings = [
self.encoder.encode(face), self.encoder.encode(face),
self.encoder.encode(other), self.encoder.encode(other),
@@ -55,26 +72,32 @@ class FaceClipEncoder(AbstractEncoder):
def encode(self, img): def encode(self, img):
if isinstance(img, list): if isinstance(img, list):
# Uncondition # Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img) return self(img)
class FaceIdClipEncoder(AbstractEncoder): class FaceIdClipEncoder(AbstractEncoder):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.encoder = FrozenCLIPImageEmbedder() self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters(): for p in self.encoder.parameters():
p.requires_grad = False p.requires_grad = False
self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True) self.id = FrozenFaceEncoder(
'/home/jpinkney/code/stable-diffusion/model_ir_se50.pth',
augment=True)
def forward(self, img): def forward(self, img):
encodings = [] encodings = []
with torch.no_grad(): with torch.no_grad():
face = kornia.geometry.resize(img, (256, 256), face = kornia.geometry.resize(
interpolation='bilinear', align_corners=True) img, (256, 256), interpolation='bilinear', align_corners=True)
other = img.clone() other = img.clone()
other[:,:,184:452,122:396] *= 0 other[:, :, 184:452, 122:396] *= 0
encodings = [ encodings = [
self.id.encode(face), self.id.encode(face),
self.encoder.encode(other), self.encoder.encode(other),
@@ -85,11 +108,15 @@ class FaceIdClipEncoder(AbstractEncoder):
def encode(self, img): def encode(self, img):
if isinstance(img, list): if isinstance(img, list):
# Uncondition # Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img) return self(img)
class ClassEmbedder(nn.Module): class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'): def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__() super().__init__()
self.key = key self.key = key
@@ -106,11 +133,19 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder): class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers""" """Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(self,
n_embed,
n_layer,
vocab_size,
max_seq_len=77,
device='cuda'):
super().__init__() super().__init__()
self.device = device self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, self.transformer = TransformerWrapper(
attn_layers=Encoder(dim=n_embed, depth=n_layer)) num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens): def forward(self, tokens):
tokens = tokens.to(self.device) # meh tokens = tokens.to(self.device) # meh
@@ -123,18 +158,25 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder): class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__() super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements from transformers import BertTokenizerFast # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
self.device = device self.device = device
self.vq_interface = vq_interface self.vq_interface = vq_interface
self.max_length = max_length self.max_length = max_length
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
tokens = batch_encoding["input_ids"].to(self.device) truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
return tokens return tokens
@torch.no_grad() @torch.no_grad()
@@ -150,20 +192,30 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder): class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers""" """Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0): def __init__(self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0):
super().__init__() super().__init__()
self.use_tknz_fn = use_tokenizer self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn: if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len)
self.device = device self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, self.transformer = TransformerWrapper(
attn_layers=Encoder(dim=n_embed, depth=n_layer), num_tokens=vocab_size,
emb_dropout=embedding_dropout) max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text): def forward(self, text):
if self.use_tknz_fn: if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device) tokens = self.tknz_fn(text) # .to(self.device)
else: else:
tokens = text tokens = text
z = self.transformer(tokens, return_embeddings=True) z = self.transformer(tokens, return_embeddings=True)
@@ -174,8 +226,6 @@ class BERTEmbedder(AbstractEncoder):
return self(text) return self(text)
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True): def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode """Overwrite model.train with this function to make sure train/eval mode
does not change anymore.""" does not change anymore."""
@@ -184,24 +234,41 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder): class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text""" """Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(self,
version='google/t5-v1_1-large',
device='cuda',
max_length=77
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__() super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') self.tokenizer = T5Tokenizer.from_pretrained(
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.transformer = T5EncoderModel.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.device = device self.device = device
self.max_length = max_length # TODO: typical value? self.max_length = max_length # TODO: typical value?
self.freeze() self.freeze()
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
#self.train = disabled_train # self.train = disabled_train
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
tokens = batch_encoding["input_ids"].to(self.device) truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens) outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state z = outputs.last_hidden_state
@@ -210,10 +277,9 @@ class FrozenT5Embedder(AbstractEncoder):
def encode(self, text): def encode(self, text):
return self(text) return self(text)
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
import kornia.augmentation as K
class FrozenFaceEncoder(AbstractEncoder): class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False): def __init__(self, model_path, augment=False):
super().__init__() super().__init__()
self.loss_fn = IDFeatures(model_path) self.loss_fn = IDFeatures(model_path)
@@ -242,8 +308,8 @@ class FrozenFaceEncoder(AbstractEncoder):
if self.augment is not None: if self.augment is not None:
# Transforms require 0-1 # Transforms require 0-1
img = self.augment((img + 1)/2) img = self.augment((img + 1) / 2)
img = 2*img - 1 img = 2 * img - 1
feat = self.loss_fn(img, crop=True) feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1)) feat = self.mapper(feat.unsqueeze(1))
@@ -252,26 +318,43 @@ class FrozenFaceEncoder(AbstractEncoder):
def encode(self, img): def encode(self, img):
return self(img) return self(img)
class FrozenCLIPEmbedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') self.tokenizer = CLIPTokenizer.from_pretrained(
self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.transformer = CLIPTextModel.from_pretrained(
version,
cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models'
)
self.device = device self.device = device
self.max_length = max_length # TODO: typical value? self.max_length = max_length # TODO: typical value?
self.freeze() self.freeze()
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
#self.train = disabled_train # self.train = disabled_train
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
tokens = batch_encoding["input_ids"].to(self.device) truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens) outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state z = outputs.last_hidden_state
@@ -280,36 +363,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text): def encode(self, text):
return self(text) return self(text)
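For orientation, what this forward pass reduces to when driven through the transformers API directly (a sketch; the hard-coded cache_dir above is environment-specific and omitted here):

from transformers import CLIPTextModel, CLIPTokenizer

version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version).eval()

batch = tokenizer(['a photo of a cat'], truncation=True, max_length=77,
                  padding='max_length', return_tensors='pt')
z = transformer(input_ids=batch['input_ids']).last_hidden_state  # [1, 77, 768]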
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder): class ClipImageProjector(AbstractEncoder):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
""" """
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
max_length=77): # clip-vit-base-patch32
super().__init__() super().__init__()
self.model = CLIPVisionModel.from_pretrained(version) self.model = CLIPVisionModel.from_pretrained(version)
self.model.train() self.model.train()
self.max_length = max_length # TODO: typical value? self.max_length = max_length # TODO: typical value?
self.antialias = True self.antialias = True
self.mapper = torch.nn.Linear(1024, 768) self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer(
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
null_cond = self.get_null_cond(version, max_length) null_cond = self.get_null_cond(version, max_length)
self.register_buffer('null_cond', null_cond) self.register_buffer('null_cond', null_cond)
@torch.no_grad() @torch.no_grad()
def get_null_cond(self, version, max_length): def get_null_cond(self, version, max_length):
device = self.mean.device device = self.mean.device
embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) embedder = FrozenCLIPEmbedder(
null_cond = embedder([""]) version=version, device=device, max_length=max_length)
null_cond = embedder([''])
return null_cond return null_cond
def preprocess(self, x): def preprocess(self, x):
# Expects inputs in the range -1, 1 # Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224), x = kornia.geometry.resize(
interpolation='bicubic',align_corners=True, x, (224, 224),
antialias=self.antialias) interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2. x = (x + 1.) / 2.
# renormalize according to clip # renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std) x = kornia.enhance.normalize(x, self.mean, self.std)
@@ -323,15 +417,23 @@ class ClipImageProjector(AbstractEncoder):
outputs = self.model(pixel_values=x) outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state) last_hidden_state = self.mapper(last_hidden_state)
return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0]) return F.pad(
last_hidden_state,
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0])
def encode(self, im): def encode(self, im):
return self(im) return self(im)
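A shape-oriented sketch of the image projector above (an assumption-laden illustration, not part of the commit): inputs are RGB tensors in [-1, 1], preprocess resizes to 224x224 and applies CLIP normalization, the linear mapper projects the 1024-wide vision tokens to 768, and the final F.pad fixes the sequence length at max_length.

# Hedged sketch; assumes the full forward() applies preprocess() before the vision model,
# and that the CLIP weights / hard-coded cache_dir are reachable.
import torch

projector = ClipImageProjector(max_length=77)
img = torch.rand(2, 3, 256, 256) * 2 - 1   # any resolution; preprocess resizes to 224x224
with torch.no_grad():
    cond = projector.encode(img)
print(cond.shape)  # expected [2, 77, 768]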
class ProjectedFrozenCLIPEmbedder(AbstractEncoder): class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__() super().__init__()
self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) self.embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length)
self.projection = torch.nn.Linear(768, 768) self.projection = torch.nn.Linear(768, 768)
def forward(self, text): def forward(self, text):
@@ -341,31 +443,41 @@ class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text): def encode(self, text):
return self(text) return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder): class FrozenCLIPImageEmbedder(AbstractEncoder):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg Not actually frozen... If you want that set cond_stage_trainable=False in cfg
""" """
def __init__( def __init__(
self, self,
model='ViT-L/14', model='ViT-L/14',
jit=False, jit=False,
device='cpu', device='cpu',
antialias=False, antialias=False,
): ):
super().__init__() super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit) self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it # We don't use the text part so delete it
del self.model.transformer del self.model.transformer
self.antialias = antialias self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer(
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
def preprocess(self, x): def preprocess(self, x):
# Expects inputs in the range -1, 1 # Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224), x = kornia.geometry.resize(
interpolation='bicubic',align_corners=True, x, (224, 224),
antialias=self.antialias) interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2. x = (x + 1.) / 2.
# renormalize according to clip # renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std) x = kornia.enhance.normalize(x, self.mean, self.std)
@@ -382,35 +494,41 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
def encode(self, im): def encode(self, im):
return self(im).unsqueeze(1) return self(im).unsqueeze(1)
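For contrast with the projector above, a sketch of what this OpenAI-clip-based encoder yields: a single pooled embedding per image, unsqueezed into one conditioning token. It assumes the clip package can fetch ViT-L/14 weights.

# Sketch only; clip downloads ViT-L/14 on first use.
import torch

image_embedder = FrozenCLIPImageEmbedder(model='ViT-L/14', device='cpu')
img = torch.rand(4, 3, 512, 512) * 2 - 1   # inputs expected in [-1, 1]
with torch.no_grad():
    c = image_embedder.encode(img)
print(c.shape)  # expected [4, 1, 768]: one pooled CLIP token per image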
from torchvision import transforms
import random
class FrozenCLIPImageMutliEmbedder(AbstractEncoder): class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
""" """
Uses the CLIP image encoder. Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg Not actually frozen... If you want that set cond_stage_trainable=False in cfg
""" """
def __init__( def __init__(
self, self,
model='ViT-L/14', model='ViT-L/14',
jit=False, jit=False,
device='cpu', device='cpu',
antialias=True, antialias=True,
max_crops=5, max_crops=5,
): ):
super().__init__() super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit) self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it # We don't use the text part so delete it
del self.model.transformer del self.model.transformer
self.antialias = antialias self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer(
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
self.max_crops = max_crops self.max_crops = max_crops
def preprocess(self, x): def preprocess(self, x):
# Expects inputs in the range -1, 1 # Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1)) randcrop = transforms.RandomResizedCrop(
224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops max_crops = self.max_crops
patches = [] patches = []
crops = [randcrop(x) for _ in range(max_crops)] crops = [randcrop(x) for _ in range(max_crops)]
@@ -441,7 +559,9 @@ class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
def encode(self, im): def encode(self, im):
return self(im) return self(im)
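A toy illustration (not the class itself) of the multi-crop preprocessing used by FrozenCLIPImageMutliEmbedder: each input is turned into max_crops square 224x224 random crops, which are then encoded crop by crop.

# Standalone toy example with the same RandomResizedCrop settings as above.
import torch
from torchvision import transforms

randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
x = torch.rand(2, 3, 512, 512) * 2 - 1
crops = [randcrop(x) for _ in range(5)]   # max_crops independent crops
print(len(crops), crops[0].shape)         # 5, torch.Size([2, 3, 224, 224])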
class SpatialRescaler(nn.Module): class SpatialRescaler(nn.Module):
def __init__(self, def __init__(self,
n_stages=1, n_stages=1,
method='bilinear', method='bilinear',
@@ -452,19 +572,24 @@ class SpatialRescaler(nn.Module):
super().__init__() super().__init__()
self.n_stages = n_stages self.n_stages = n_stages
assert self.n_stages >= 0 assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] assert method in [
'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'
]
self.multiplier = multiplier self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.interpolator = partial(
torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None self.remap_output = out_channels is not None
if self.remap_output: if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') print(
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias)
def forward(self,x): def forward(self, x):
for stage in range(self.n_stages): for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier) x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output: if self.remap_output:
x = self.channel_mapper(x) x = self.channel_mapper(x)
return x return x
@@ -473,25 +598,38 @@ class SpatialRescaler(nn.Module):
return self(x) return self(x)
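A small usage sketch of SpatialRescaler, assuming the in_channels/out_channels keywords from the full signature: the interpolator is applied n_stages times with the given multiplier and, when out_channels is set, a 1x1 convolution remaps the channels.

# Illustrative values only.
import torch

rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5,
                           in_channels=3, out_channels=16)
y = rescaler(torch.rand(1, 3, 64, 64))
print(y.shape)  # expected [1, 16, 16, 16]: two 0.5x resizes, then the 1x1 channel mapper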
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
class LowScaleEncoder(nn.Module): class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
def __init__(self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0): scale_factor=1.0):
super().__init__() super().__init__()
self.max_noise_level = max_noise_level self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config) self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start, self.augmentation_schedule = self.register_schedule(
linear_end=linear_end) timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size self.out_size = output_size
self.scale_factor = scale_factor self.scale_factor = scale_factor
def register_schedule(self, beta_schedule="linear", timesteps=1000, def register_schedule(self,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): beta_schedule='linear',
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, timesteps=1000,
cosine_s=cosine_s) linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
@@ -500,33 +638,45 @@ class LowScaleEncoder(nn.Module):
self.num_timesteps = int(timesteps) self.num_timesteps = int(timesteps)
self.linear_start = linear_start self.linear_start = linear_start
self.linear_end = linear_end self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' assert alphas_cumprod.shape[
0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32) to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas)) self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) self.register_buffer('alphas_cumprod_prev',
to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) self.register_buffer('sqrt_alphas_cumprod',
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) self.register_buffer('sqrt_one_minus_alphas_cumprod',
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
x_start.shape) * noise)
def forward(self, x): def forward(self, x):
z = self.model.encode(x).sample() z = self.model.encode(x).sample()
z = z * self.scale_factor z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0], ), device=x.device).long()
z = self.q_sample(z, noise_level) z = self.q_sample(z, noise_level)
if self.out_size is not None: if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode z = torch.nn.functional.interpolate(
z, size=self.out_size,
mode='nearest') # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level return z, noise_level
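The q_sample call above applies the standard forward-diffusion marginal, z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps. A standalone numerical sketch (with an illustrative beta ramp, not the exact make_beta_schedule output):

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)          # illustrative linear ramp
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

z0 = torch.randn(2, 4, 64, 64)                    # clean latents
t = torch.randint(0, 250, (2,))                   # noise levels, as drawn in forward() above
eps = torch.randn_like(z0)
zt = (alphas_cumprod[t].sqrt().view(-1, 1, 1, 1) * z0
      + (1. - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1) * eps)
print(zt.shape, t.tolist())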
@@ -535,10 +685,13 @@ class LowScaleEncoder(nn.Module):
return self.model.decode(z) return self.model.decode(z)
if __name__ == "__main__": if __name__ == '__main__':
from ldm.util import count_params from ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"] sentences = [
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda() 'a hedgehog drinking a whiskey', 'der mond ist aufgegangen',
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"
]
model = FrozenT5Embedder(version='google/t5-v1_1-xl').cuda()
count_params(model, True) count_params(model, True)
z = model(sentences) z = model(sentences)
print(z.shape) print(z.shape)
@@ -548,4 +701,4 @@ if __name__ == "__main__":
z = model(sentences) z = model(sentences)
print(z.shape) print(z.shape)
print("done.") print('done.')
View File
@@ -1,28 +1,26 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch from collections import namedtuple
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial from functools import partial
from inspect import isfunction from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn
# constants # constants
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [ Intermediates = namedtuple('Intermediates',
'pre_softmax_attn', ['pre_softmax_attn', 'post_softmax_attn'])
'post_softmax_attn'
])
LayerIntermediates = namedtuple('Intermediates', [ LayerIntermediates = namedtuple('Intermediates',
'hiddens', ['hiddens', 'attn_intermediates'])
'attn_intermediates'
])
class AbsolutePositionalEmbedding(nn.Module): class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len): def __init__(self, dim, max_seq_len):
super().__init__() super().__init__()
self.emb = nn.Embedding(max_seq_len, dim) self.emb = nn.Embedding(max_seq_len, dim)
@@ -37,13 +35,15 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0): def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset t = torch.arange(
x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :] return emb[None, :, :]
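A quick shape check for the sinusoidal table above: only the sequence length and device of x are used, and the sin/cos halves are concatenated along the feature dimension.

import torch

fpe = FixedPositionalEmbedding(dim=64)
emb = fpe(torch.zeros(2, 10, 64))
print(emb.shape)  # [1, 10, 64]: broadcastable over the batch dimension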
@@ -51,6 +51,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers # helpers
def exists(val): def exists(val):
return val is not None return val is not None
@@ -62,20 +63,26 @@ def default(val, d):
def always(val): def always(val):
def inner(*args, **kwargs): def inner(*args, **kwargs):
return val return val
return inner return inner
def not_equals(val): def not_equals(val):
def inner(x): def inner(x):
return x != val return x != val
return inner return inner
def equals(val): def equals(val):
def inner(x): def inner(x):
return x == val return x == val
return inner return inner
@@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers # keyword argument helpers
def pick_and_pop(keys, d): def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys)) values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values)) return dict(zip(keys, values))
@@ -96,7 +104,7 @@ def group_dict_by_key(cond, d):
match = bool(cond(key)) match = bool(cond(key))
ind = int(not match) ind = int(not match)
return_val[ind][key] = d[key] return_val[ind][key] = d[key]
return (*return_val,) return (*return_val, )
def string_begins_with(prefix, str): def string_begins_with(prefix, str):
@@ -108,13 +116,17 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d): def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_with_prefix, kwargs = group_dict_by_key(
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix):], x[1]),
tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
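An illustration of the kwargs routing this helper enables (made-up values); AttentionLayers below uses it to split ff_* and attn_* arguments out of a single **kwargs.

ff_kwargs, rest = groupby_prefix_and_trim('ff_', {'ff_mult': 4, 'ff_glu': True, 'heads': 8})
print(ff_kwargs)  # {'mult': 4, 'glu': True} -- prefix stripped, ready for FeedForward(**ff_kwargs)
print(rest)       # {'heads': 8}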
# classes # classes
class Scale(nn.Module): class Scale(nn.Module):
def __init__(self, value, fn): def __init__(self, value, fn):
super().__init__() super().__init__()
self.value = value self.value = value
@@ -126,6 +138,7 @@ class Scale(nn.Module):
class Rezero(nn.Module): class Rezero(nn.Module):
def __init__(self, fn): def __init__(self, fn):
super().__init__() super().__init__()
self.fn = fn self.fn = fn
@@ -137,9 +150,10 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module): class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5): def __init__(self, dim, eps=1e-5):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(1)) self.g = nn.Parameter(torch.ones(1))
@@ -149,9 +163,10 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8): def __init__(self, dim, eps=1e-8):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
@@ -161,11 +176,13 @@ class RMSNorm(nn.Module):
class Residual(nn.Module): class Residual(nn.Module):
def forward(self, x, residual): def forward(self, x, residual):
return x + residual return x + residual
class GRUGating(nn.Module): class GRUGating(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
self.gru = nn.GRUCell(dim, dim) self.gru = nn.GRUCell(dim, dim)
@@ -173,15 +190,16 @@ class GRUGating(nn.Module):
def forward(self, x, residual): def forward(self, x, residual):
gated_output = self.gru( gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'), rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d') rearrange(residual, 'b n d -> (b n) d'))
)
return gated_output.reshape_as(x) return gated_output.reshape_as(x)
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super().__init__() super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2) self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -192,20 +210,16 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(nn.Linear(
nn.Linear(dim, inner_dim), dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential( self.net = nn.Sequential(project_in, nn.Dropout(dropout),
project_in, nn.Linear(inner_dim, dim_out))
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
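With glu=True the input projection above is replaced by GEGLU, which (in the usual formulation) projects to twice the inner width and gates one half with GELU of the other. A shape sketch:

import torch

ff = FeedForward(dim=256, mult=4, glu=True)
y = ff(torch.randn(2, 10, 256))
print(y.shape)  # [2, 10, 256]; inner width is mult * dim = 1024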
@@ -213,24 +227,24 @@ class FeedForward(nn.Module):
# attention. # attention.
class Attention(nn.Module): class Attention(nn.Module):
def __init__(
self, def __init__(self,
dim, dim,
dim_head=DEFAULT_DIM_HEAD, dim_head=DEFAULT_DIM_HEAD,
heads=8, heads=8,
causal=False, causal=False,
mask=None, mask=None,
talking_heads=False, talking_heads=False,
sparse_topk=None, sparse_topk=None,
use_entmax15=False, use_entmax15=False,
num_mem_kv=0, num_mem_kv=0,
dropout=0., dropout=0.,
on_attn=False on_attn=False):
):
super().__init__() super().__init__()
if use_entmax15: if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!") raise NotImplementedError(
self.scale = dim_head ** -0.5 'Check out entmax activation instead of softmax activation!')
self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.causal = causal self.causal = causal
self.mask = mask self.mask = mask
@@ -252,7 +266,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk self.sparse_topk = sparse_topk
# entmax # entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax # self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax self.attn_fn = F.softmax
# add memory key / values # add memory key / values
@@ -263,19 +277,19 @@ class Attention(nn.Module):
# attention on attention # attention on attention
self.attn_on_attn = on_attn self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) self.to_out = nn.Sequential(nn.Linear(
inner_dim, dim
* 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
def forward( def forward(self,
self, x,
x, context=None,
context=None, mask=None,
mask=None, context_mask=None,
context_mask=None, rel_pos=None,
rel_pos=None, sinusoidal_emb=None,
sinusoidal_emb=None, prev_attn=None,
prev_attn=None, mem=None):
mem=None
):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x) kv_input = default(context, x)
@@ -297,23 +311,29 @@ class Attention(nn.Module):
k = self.to_k(k_input) k = self.to_k(k_input)
v = self.to_v(v_input) v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
input_mask = None input_mask = None
if any(map(exists, (mask, context_mask))): if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) q_mask = default(mask, lambda: torch.ones(
(b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) k_mask = default(
k_mask, lambda: torch.ones(
(b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()') q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j') k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask input_mask = q_mask * k_mask
if self.num_mem_kv > 0: if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2) k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2) v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask): if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots) mask_value = max_neg_value(dots)
@@ -324,7 +344,8 @@ class Attention(nn.Module):
pre_softmax_attn = dots pre_softmax_attn = dots
if talking_heads: if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() dots = einsum('b h i j, h k -> b k i j', dots,
self.pre_softmax_proj).contiguous()
if exists(rel_pos): if exists(rel_pos):
dots = rel_pos(dots) dots = rel_pos(dots)
@@ -336,7 +357,8 @@ class Attention(nn.Module):
if self.causal: if self.causal:
i, j = dots.shape[-2:] i, j = dots.shape[-2:]
r = torch.arange(i, device=device) r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False) mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value) dots.masked_fill_(mask, mask_value)
del mask del mask
@@ -354,59 +376,60 @@ class Attention(nn.Module):
attn = self.dropout(attn) attn = self.dropout(attn)
if talking_heads: if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() attn = einsum('b h i j, h k -> b k i j', attn,
self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates( intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn post_softmax_attn=post_softmax_attn)
)
return self.to_out(out), intermediates return self.to_out(out), intermediates
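A shape sketch for the attention block above (illustrative sizes; defaults: dim_head=64, no memory key/values): passing context switches it from self- to cross-attention, and the returned Intermediates carry the pre/post-softmax attention maps.

import torch

attn = Attention(dim=256, heads=8)
x = torch.randn(2, 10, 256)
ctx = torch.randn(2, 7, 256)
out, inter = attn(x)                 # self-attention
out_c, _ = attn(x, context=ctx)      # cross-attention over 7 context tokens
print(out.shape, out_c.shape)        # [2, 10, 256] [2, 10, 256]
print(inter.pre_softmax_attn.shape)  # [2, 8, 10, 10]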
class AttentionLayers(nn.Module): class AttentionLayers(nn.Module):
def __init__(
self, def __init__(self,
dim, dim,
depth, depth,
heads=8, heads=8,
causal=False, causal=False,
cross_attend=False, cross_attend=False,
only_cross=False, only_cross=False,
use_scalenorm=False, use_scalenorm=False,
use_rmsnorm=False, use_rmsnorm=False,
use_rezero=False, use_rezero=False,
rel_pos_num_buckets=32, rel_pos_num_buckets=32,
rel_pos_max_distance=128, rel_pos_max_distance=128,
position_infused_attn=False, position_infused_attn=False,
custom_layers=None, custom_layers=None,
sandwich_coef=None, sandwich_coef=None,
par_ratio=None, par_ratio=None,
residual_attn=False, residual_attn=False,
cross_residual_attn=False, cross_residual_attn=False,
macaron=False, macaron=False,
pre_norm=True, pre_norm=True,
gate_residual=False, gate_residual=False,
**kwargs **kwargs):
):
super().__init__() super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) # dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
self.dim = dim self.dim = dim
self.depth = depth self.depth = depth
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None self.pia_pos_emb = FixedPositionalEmbedding(
dim) if position_infused_attn else None
self.rotary_pos_emb = always(None) self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' assert rel_pos_num_buckets <= rel_pos_max_distance, \
'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None self.rel_pos = None
self.pre_norm = pre_norm self.pre_norm = pre_norm
@@ -429,7 +452,7 @@ class AttentionLayers(nn.Module):
default_block = ('a', 'f') default_block = ('a', 'f')
if macaron: if macaron:
default_block = ('f',) + default_block default_block = ('f', ) + default_block
if exists(custom_layers): if exists(custom_layers):
layer_types = custom_layers layer_types = custom_layers
@@ -440,13 +463,17 @@ class AttentionLayers(nn.Module):
par_attn = par_depth // par_ratio par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio' assert len(
par_block = default_block + ('f',) * (par_width - len(default_block)) default_block
) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f', ) * (
par_width - len(default_block))
par_head = par_block * par_attn par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head)) layer_types = par_head + ('f', ) * (par_depth - len(par_head))
elif exists(sandwich_coef): elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef layer_types = ('a', ) * sandwich_coef + default_block * (
depth - sandwich_coef) + ('f', ) * sandwich_coef
else: else:
layer_types = default_block * depth layer_types = default_block * depth
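Worked examples of the layer-type layouts chosen above (no par_ratio or custom_layers); 'a' is self-attention, 'f' feedforward.

default_block = ('a', 'f')
depth = 3
print(default_block * depth)  # ('a', 'f', 'a', 'f', 'a', 'f')

sandwich_coef = 1             # sandwich layout: attention-heavy front, feedforward-heavy tail
print(('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef)
# ('a', 'a', 'f', 'a', 'f', 'f')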
@@ -455,7 +482,8 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types: for layer_type in self.layer_types:
if layer_type == 'a': if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c': elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs) layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f': elif layer_type == 'f':
@@ -472,21 +500,15 @@ class AttentionLayers(nn.Module):
else: else:
residual_fn = Residual() residual_fn = Residual()
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
norm_fn(),
layer,
residual_fn
]))
def forward( def forward(self,
self, x,
x, context=None,
context=None, mask=None,
mask=None, context_mask=None,
context_mask=None, mems=None,
mems=None, return_hiddens=False):
return_hiddens=False
):
hiddens = [] hiddens = []
intermediates = [] intermediates = []
prev_attn = None prev_attn = None
@@ -494,7 +516,8 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1) is_last = ind == (len(self.layers) - 1)
if layer_type == 'a': if layer_type == 'a':
@@ -507,10 +530,20 @@ class AttentionLayers(nn.Module):
x = norm(x) x = norm(x)
if layer_type == 'a': if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, out, inter = block(
prev_attn=prev_attn, mem=layer_mem) x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem)
elif layer_type == 'c': elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn)
elif layer_type == 'f': elif layer_type == 'f':
out = block(x) out = block(x)
@@ -529,9 +562,7 @@ class AttentionLayers(nn.Module):
if return_hiddens: if return_hiddens:
intermediates = LayerIntermediates( intermediates = LayerIntermediates(
hiddens=hiddens, hiddens=hiddens, attn_intermediates=intermediates)
attn_intermediates=intermediates
)
return x, intermediates return x, intermediates
@@ -539,28 +570,29 @@ class AttentionLayers(nn.Module):
class Encoder(AttentionLayers): class Encoder(AttentionLayers):
def __init__(self, **kwargs): def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder' assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs) super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module): class TransformerWrapper(nn.Module):
def __init__(
self, def __init__(self,
*, *,
num_tokens, num_tokens,
max_seq_len, max_seq_len,
attn_layers, attn_layers,
emb_dim=None, emb_dim=None,
max_mem_len=0., max_mem_len=0.,
emb_dropout=0., emb_dropout=0.,
num_memory_tokens=None, num_memory_tokens=None,
tie_embedding=False, tie_embedding=False,
use_pos_emb=True use_pos_emb=True):
):
super().__init__() super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim dim = attn_layers.dim
emb_dim = default(emb_dim, dim) emb_dim = default(emb_dim, dim)
@@ -571,22 +603,26 @@ class TransformerWrapper(nn.Module):
self.token_emb = nn.Embedding(num_tokens, emb_dim) self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.project_emb = nn.Linear(emb_dim,
dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.init_() self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() self.to_logits = nn.Linear(
dim, num_tokens
) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper # memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0) num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0: if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim))
# let funnel encoder know number of memory tokens, if specified # let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'): if hasattr(attn_layers, 'num_memory_tokens'):
@@ -595,17 +631,16 @@ class TransformerWrapper(nn.Module):
def init_(self): def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02) nn.init.normal_(self.token_emb.weight, std=0.02)
def forward( def forward(self,
self, x,
x, return_embeddings=False,
return_embeddings=False, mask=None,
mask=None, return_mems=False,
return_mems=False, return_attn=False,
return_attn=False, mems=None,
mems=None, **kwargs):
**kwargs # b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
): b, _, num_mem = *x.shape, self.num_memory_tokens
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x) x = self.token_emb(x)
x += self.pos_emb(x) x += self.pos_emb(x)
x = self.emb_dropout(x) x = self.emb_dropout(x)
@@ -620,7 +655,8 @@ class TransformerWrapper(nn.Module):
if exists(mask): if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True) mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x) x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:] mem, x = x[:, :num_mem], x[:, num_mem:]
@@ -629,13 +665,18 @@ class TransformerWrapper(nn.Module):
if return_mems: if return_mems:
hiddens = intermediates.hiddens hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens new_mems = list(
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) map(lambda pair: torch.cat(pair, dim=-2), zip(
mems, hiddens))) if exists(mems) else hiddens
new_mems = list(
map(lambda t: t[..., -self.max_mem_len:, :].detach(),
new_mems))
return out, new_mems return out, new_mems
if return_attn: if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) attn_maps = list(
map(lambda t: t.post_softmax_attn,
intermediates.attn_intermediates))
return out, attn_maps return out, attn_maps
return out return out
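An end-to-end sketch of Encoder plus TransformerWrapper (hyperparameters are illustrative assumptions): token ids go in, and with return_embeddings=True the normalized hidden states come back instead of logits.

import torch

model = TransformerWrapper(
    num_tokens=30522,        # illustrative vocabulary size
    max_seq_len=77,
    attn_layers=Encoder(dim=512, depth=4, heads=8))
tokens = torch.randint(0, 30522, (2, 77))
out = model(tokens, return_embeddings=True)
print(out.shape)             # expected [2, 77, 512]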
View File
@@ -1,121 +1,134 @@
# https://github.com/eladrich/pixel2style2pixel # https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
""" """
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
""" """
from collections import namedtuple
import torch
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d,
Module, PReLU, ReLU, Sequential, Sigmoid)
class Flatten(Module): class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1) def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1): def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True) norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm) output = torch.div(input, norm)
return output return output
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
""" A named tuple describing a ResNet block. """ """ A named tuple describing a ResNet block. """
def get_block(in_channel, depth, num_units, stride=2): def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] return [Bottleneck(in_channel, depth, stride)
] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
def get_blocks(num_layers): def get_blocks(num_layers):
if num_layers == 50: if num_layers == 50:
blocks = [ blocks = [
get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4), get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14), get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3) get_block(in_channel=256, depth=512, num_units=3)
] ]
elif num_layers == 100: elif num_layers == 100:
blocks = [ blocks = [
get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13), get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30), get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3) get_block(in_channel=256, depth=512, num_units=3)
] ]
elif num_layers == 152: elif num_layers == 152:
blocks = [ blocks = [
get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8), get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36), get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3) get_block(in_channel=256, depth=512, num_units=3)
] ]
else: else:
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) raise ValueError(
return blocks 'Invalid number of layers: {}. Must be one of [50, 100, 152]'.
format(num_layers))
return blocks
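What a single get_block call expands to: a stride-2 unit that changes the channel count, followed by stride-1 units at the target depth.

blocks = get_block(in_channel=64, depth=128, num_units=4)
for b in blocks:
    print(b.in_channel, b.depth, b.stride)
# 64 128 2
# 128 128 1
# 128 128 1
# 128 128 1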
class SEModule(Module): class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = Sigmoid()
def forward(self, x): def __init__(self, channels, reduction):
module_input = x super(SEModule, self).__init__()
x = self.avg_pool(x) self.avg_pool = AdaptiveAvgPool2d(1)
x = self.fc1(x) self.fc1 = Conv2d(
x = self.relu(x) channels,
x = self.fc2(x) channels // reduction,
x = self.sigmoid(x) kernel_size=1,
return module_input * x padding=0,
bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module): class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
)
def forward(self, x): def __init__(self, in_channel, depth, stride):
shortcut = self.shortcut_layer(x) super(bottleneck_IR, self).__init__()
res = self.res_layer(x) if in_channel == depth:
return res + shortcut self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module): class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16)
)
def forward(self, x): def __init__(self, in_channel, depth, stride):
shortcut = self.shortcut_layer(x) super(bottleneck_IR_SE, self).__init__()
res = self.res_layer(x) if in_channel == depth:
return res + shortcut self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth), SEModule(depth, 16))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
View File
@@ -1,22 +1,26 @@
# https://github.com/eladrich/pixel2style2pixel # https://github.com/eladrich/pixel2style2pixel
import torch import torch
from torch import nn from torch import nn
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone
class IDFeatures(nn.Module): class IDFeatures(nn.Module):
def __init__(self, model_path): def __init__(self, model_path):
super(IDFeatures, self).__init__() super(IDFeatures, self).__init__()
print('Loading ResNet ArcFace') print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') self.facenet = Backbone(
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(
torch.load(model_path, map_location='cpu'))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval() self.facenet.eval()
def forward(self, x, crop=False): def forward(self, x, crop=False):
# Not sure of the image range here # Not sure of the image range here
if crop: if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode="area") x = torch.nn.functional.interpolate(x, (256, 256), mode='area')
x = x[:, :, 35:223, 32:220] x = x[:, :, 35:223, 32:220]
x = self.face_pool(x) x = self.face_pool(x)
x_feats = self.facenet(x) x_feats = self.facenet(x)
View File
@@ -1,86 +1,97 @@
# https://github.com/eladrich/pixel2style2pixel # https://github.com/eladrich/pixel2style2pixel
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm Module, PReLU, Sequential)
""" from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import (
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, l2_norm)
"""
class Backbone(Module): class Backbone(Module):
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): """
super(Backbone, self).__init__() Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
assert input_size in [112, 224], "input_size should be 112 or 224" """
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 7 * 7, 512),
BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 14 * 14, 512),
BatchNorm1d(512, affine=affine))
modules = [] def __init__(self,
for block in blocks: input_size,
for bottleneck in block: num_layers,
modules.append(unit_module(bottleneck.in_channel, mode='ir',
bottleneck.depth, drop_ratio=0.4,
bottleneck.stride)) affine=True):
self.body = Sequential(*modules) super(Backbone, self).__init__()
assert input_size in [112, 224], 'input_size should be 112 or 224'
assert num_layers in [50, 100,
152], 'num_layers should be 50, 100 or 152'
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine))
def forward(self, x): modules = []
x = self.input_layer(x) for block in blocks:
x = self.body(x) for bottleneck in block:
x = self.output_layer(x) modules.append(
return l2_norm(x) unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size): def IR_50(input_size):
"""Constructs a ir-50 model.""" """Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_101(input_size): def IR_101(input_size):
"""Constructs a ir-101 model.""" """Constructs a ir-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_152(input_size): def IR_152(input_size):
"""Constructs a ir-152 model.""" """Constructs a ir-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_SE_50(input_size): def IR_SE_50(input_size):
"""Constructs a ir_se-50 model.""" """Constructs a ir_se-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_101(input_size): def IR_SE_101(input_size):
"""Constructs a ir_se-101 model.""" """Constructs a ir_se-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_152(input_size): def IR_SE_152(input_size):
"""Constructs a ir_se-152 model.""" """Constructs a ir_se-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) model = Backbone(
return model input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
return model
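A usage sketch for the backbones above (random weights here; the real pipeline loads the ArcFace checkpoint via IDFeatures): aligned 112x112 faces map to L2-normalized 512-d identity embeddings.

import torch

net = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 112, 112))
print(feats.shape, feats.norm(dim=1))  # [1, 512], norms ~1.0 after l2_norm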
View File
@@ -1,32 +1,24 @@
import importlib import importlib
import torchvision
import torch
from torch import optim
import numpy as np
from inspect import isfunction from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import time
import cv2 import cv2
import numpy as np
import PIL import PIL
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def pil_rectangle_crop(im): def pil_rectangle_crop(im):
width, height = im.size # Get dimensions width, height = im.size # Get dimensions
if width <= height: if width <= height:
left = 0 left = 0
right = width right = width
top = (height - width)/2 top = (height - width) / 2
bottom = (height + width)/2 bottom = (height + width) / 2
else: else:
top = 0 top = 0
bottom = height bottom = height
left = (width - height) / 2 left = (width - height) / 2
@@ -36,6 +28,7 @@ def pil_rectangle_crop(im):
im = im.crop((left, top, right, bottom)) im = im.crop((left, top, right, bottom))
return im return im
def add_margin(pil_img, color=0, size=256): def add_margin(pil_img, color=0, size=256):
width, height = pil_img.size width, height = pil_img.size
result = Image.new(pil_img.mode, (size, size), color) result = Image.new(pil_img.mode, (size, size), color)
@@ -46,16 +39,17 @@ def add_margin(pil_img, color=0, size=256):
def create_carvekit_interface(): def create_carvekit_interface():
from carvekit.api.high import HiInterface from carvekit.api.high import HiInterface
# Check doc strings for more information # Check doc strings for more information
interface = HiInterface(object_type="object", # Can be "object" or "hairs-like". interface = HiInterface(
batch_size_seg=5, object_type='object', # Can be "object" or "hairs-like".
batch_size_matting=1, batch_size_seg=5,
device='cuda' if torch.cuda.is_available() else 'cpu', batch_size_matting=1,
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net device='cuda' if torch.cuda.is_available() else 'cpu',
matting_mask_size=2048, seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
trimap_prob_threshold=231, matting_mask_size=2048,
trimap_dilation=30, trimap_prob_threshold=231,
trimap_erosion_iters=5, trimap_dilation=30,
fp16=False) trimap_erosion_iters=5,
fp16=False)
return interface return interface
@@ -72,17 +66,17 @@ def load_and_preprocess(interface, input_im):
image_without_background = np.array(image_without_background) image_without_background = np.array(image_without_background)
est_seg = image_without_background > 127 est_seg = image_without_background > 127
image = np.array(image) image = np.array(image)
foreground = est_seg[:, : , -1].astype(np.bool_) foreground = est_seg[:, :, -1].astype(np.bool_)
image[~foreground] = [255., 255., 255.] image[~foreground] = [255., 255., 255.]
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
image = image[y:y+h, x:x+w, :] image = image[y:y + h, x:x + w, :]
image = PIL.Image.fromarray(np.array(image)) image = PIL.Image.fromarray(np.array(image))
# resize image such that long edge is 512 # resize image such that long edge is 512
image.thumbnail([200, 200], Image.LANCZOS) image.thumbnail([200, 200], Image.LANCZOS)
image = add_margin(image, (255, 255, 255), size=256) image = add_margin(image, (255, 255, 255), size=256)
image = np.array(image) image = np.array(image)
return image return image
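A hedged usage sketch for the preprocessing pipeline above (requires carvekit; 'example.png' is a hypothetical path): background removal, white fill, bounding-box crop, thumbnail to 200 px, then padding onto a 256x256 canvas.

from PIL import Image

interface = create_carvekit_interface()               # downloads segmentation weights on first use
input_im = Image.open('example.png').convert('RGB')   # hypothetical input path
processed = load_and_preprocess(interface, input_im)
print(processed.shape)                                # expected (256, 256, 3), object on white background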
@@ -92,16 +86,17 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc) b = len(xc)
txts = list() txts = list()
for bi in range(b): for bi in range(b):
txt = Image.new("RGB", wh, color="white") txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt) draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256)) nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) lines = '\n'.join(xc[bi][start:start + nc]
for start in range(0, len(xc[bi]), nc))
try: try:
draw.text((0, 0), lines, fill="black", font=font) draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError: except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.") print('Cant encode string for logging. Skipping.')
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt) txts.append(txt)
@@ -117,7 +112,7 @@ def ismap(x):
def isimage(x): def isimage(x):
if not isinstance(x,torch.Tensor): if not isinstance(x, torch.Tensor):
return False return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
@@ -143,22 +138,24 @@ def mean_flat(tensor):
def count_params(model, verbose=False): def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
if verbose: if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") print(
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
)
return total_params return total_params
def instantiate_from_config(config): def instantiate_from_config(config):
if not "target" in config: if 'target' not in config:
if config == '__is_first_stage__': if config == '__is_first_stage__':
return None return None
elif config == "__is_unconditional__": elif config == '__is_unconditional__':
return None return None
raise KeyError("Expected key `target` to instantiate.") raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config["target"])(**config.get("params", dict())) return get_obj_from_str(config['target'])(**config.get('params', dict()))
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1) module, cls = string.rsplit('.', 1)
print(module) print(module)
if reload: if reload:
module_imp = importlib.import_module(module) module_imp = importlib.import_module(module)
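As a reminder of the config contract enforced above (a dotted `target` path plus optional `params`), here is a minimal usage sketch; the torch.nn.Linear target is only an illustrative choice, not something this commit touches.

# Sketch only: any importable dotted path works as `target`; `params` is unpacked
# as keyword arguments into the resolved class or function.
cfg = {'target': 'torch.nn.Linear', 'params': {'in_features': 8, 'out_features': 4}}
layer = instantiate_from_config(cfg)   # equivalent to torch.nn.Linear(8, 4)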
@@ -168,25 +165,42 @@ def get_obj_from_str(string, reload=False):
class AdamWwithEMAandWings(optim.Optimizer): class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using def __init__(
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code self,
ema_power=1., param_names=()): params,
lr=1.e-3,
betas=(0.9, 0.999),
eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2,
amsgrad=False,
ema_decay=0.9999, # ema decay to match previous code
ema_power=1.,
param_names=()): # noqa
"""AdamW that saves EMA versions of the parameters.""" """AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError('Invalid epsilon value: {}'.format(eps))
if not 0.0 <= betas[0] < 1.0: if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] < 1.0: if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not 0.0 <= ema_decay <= 1.0: if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, defaults = dict(
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, lr=lr,
ema_power=ema_power, param_names=param_names) betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names)
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
@@ -212,7 +226,7 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = [] exp_avgs = []
exp_avg_sqs = [] exp_avg_sqs = []
ema_params_with_grad = [] ema_params_with_grad = []
state_sums = [] # state_sums = []
max_exp_avg_sqs = [] max_exp_avg_sqs = []
state_steps = [] state_steps = []
amsgrad = group['amsgrad'] amsgrad = group['amsgrad']
@@ -225,7 +239,8 @@ class AdamWwithEMAandWings(optim.Optimizer):
continue continue
params_with_grad.append(p) params_with_grad.append(p)
if p.grad.is_sparse: if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients') raise RuntimeError(
'AdamW does not support sparse gradients')
grads.append(p.grad) grads.append(p.grad)
state = self.state[p] state = self.state[p]
@@ -234,12 +249,15 @@ class AdamWwithEMAandWings(optim.Optimizer):
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) state['exp_avg'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) state['exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) state['max_exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values # Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone() state['param_exp_avg'] = p.detach().float().clone()
@@ -255,22 +273,25 @@ class AdamWwithEMAandWings(optim.Optimizer):
# record the step after step update # record the step after step update
state_steps.append(state['step']) state_steps.append(state['step'])
optim._functional.adamw(params_with_grad, optim._functional.adamw(
grads, params_with_grad,
exp_avgs, grads,
exp_avg_sqs, exp_avgs,
max_exp_avg_sqs, exp_avg_sqs,
state_steps, max_exp_avg_sqs,
amsgrad=amsgrad, state_steps,
beta1=beta1, amsgrad=amsgrad,
beta2=beta2, beta1=beta1,
lr=group['lr'], beta2=beta2,
weight_decay=group['weight_decay'], lr=group['lr'],
eps=group['eps'], weight_decay=group['weight_decay'],
maximize=False) eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad): for param, ema_param in zip(params_with_grad,
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay)
return loss return loss
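A hedged usage sketch for the optimizer above: the model and hyperparameter values are placeholders, and the EMA copy is read from the per-parameter state key `param_exp_avg` that the step() shown here maintains.

# Sketch only: one standard training step with AdamWwithEMAandWings, then collecting
# the EMA weights kept alongside the raw parameters.
import torch

model = torch.nn.Linear(8, 4)                      # placeholder model
opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4, ema_decay=0.9999)

loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()

ema_weights = [opt.state[p]['param_exp_avg'] for p in model.parameters()]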
View File
@@ -1,28 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp import os.path as osp
from typing import Any, Dict from typing import Any, Dict
import rembg
import cv2 import cv2
import numpy as np import numpy as np
import PIL import PIL
import rembg
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from torchvision.utils import save_image from torchvision.utils import save_image
from omegaconf import OmegaConf
# import modelscope.models.cv.image_to_image_generation.data as data # import modelscope.models.cv.image_to_image_generation.data as data
# import modelscope.models.cv.image_to_image_generation.models as models # import modelscope.models.cv.image_to_image_generation.models as models
# import modelscope.models.cv.image_to_image_generation.ops as ops # import modelscope.models.cv.image_to_image_generation.ops as ops
from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
# from modelscope.models.cv.image_to_3d.model import UNet from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import \
# from modelscope.models.cv.image_to_image_generation.models.clip import \ SyncMultiviewDiffusion
# VisionTransformer from modelscope.models.cv.image_to_3d.ldm.util import (add_margin,
instantiate_from_config)
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config, add_margin
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
@@ -31,23 +30,29 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
# from modelscope.models.cv.image_to_3d.model import UNet
# from modelscope.models.cv.image_to_image_generation.models.clip import \
# VisionTransformer
logger = get_logger() logger = get_logger()
# Load Syncdreamer Model # Load Syncdreamer Model
def load_model(cfg, ckpt, strict=True): def load_model(cfg, ckpt, strict=True):
config = OmegaConf.load(cfg) config = OmegaConf.load(cfg)
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
print(f'loading model from {ckpt} ...') print(f'loading model from {ckpt} ...')
ckpt = torch.load(ckpt,map_location='cpu') ckpt = torch.load(ckpt, map_location='cpu')
model.load_state_dict(ckpt['state_dict'],strict=strict) model.load_state_dict(ckpt['state_dict'], strict=strict)
model = model.cuda().eval() model = model.cuda().eval()
return model return model
# Prepare Syncdreamer Input # Prepare Syncdreamer Input
def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256): def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
image_input[:,:,:3] = image_input[:,:,:3][:,:,::-1] image_input[:, :, :3] = image_input[:, :, :3][:, :, ::-1]
image_input = Image.fromarray(image_input) image_input = Image.fromarray(image_input)
if crop_size!=-1: if crop_size != -1:
alpha_np = np.asarray(image_input)[:, :, 3] alpha_np = np.asarray(image_input)[:, :, 3]
coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
min_x, min_y = np.min(coords, 0) min_x, min_y = np.min(coords, 0)
@@ -59,21 +64,26 @@ def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256):
ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC) ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
image_input = add_margin(ref_img_, size=image_size) image_input = add_margin(ref_img_, size=image_size)
else: else:
image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) image_input = add_margin(
image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC) image_input, size=max(image_input.height, image_input.width))
image_input = image_input.resize((image_size, image_size),
resample=Image.BICUBIC)
image_input = np.asarray(image_input) image_input = np.asarray(image_input)
image_input = image_input.astype(np.float32) / 255.0 image_input = image_input.astype(np.float32) / 255.0
ref_mask = image_input[:, :, 3:] ref_mask = image_input[:, :, 3:]
image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background image_input[:, :, :
3] = image_input[:, :, :
3] * ref_mask + 1 - ref_mask # white background
image_input = image_input[:, :, :3] * 2.0 - 1.0 image_input = image_input[:, :, :3] * 2.0 - 1.0
image_input = torch.from_numpy(image_input.astype(np.float32)) image_input = torch.from_numpy(image_input.astype(np.float32))
elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32)) elevation_input = torch.from_numpy(
return {"input_image": image_input, "input_elevation": elevation_input} np.asarray([np.deg2rad(elevation_input)], np.float32))
return {'input_image': image_input, 'input_elevation': elevation_input}
@PIPELINES.register_module( @PIPELINES.register_module(
Tasks.image_to_3d, Tasks.image_to_3d, module_name=Pipelines.image_to_3d)
module_name=Pipelines.image_to_3d)
class Image23DPipeline(Pipeline): class Image23DPipeline(Pipeline):
def __init__(self, model: str, **kwargs): def __init__(self, model: str, **kwargs):
@@ -91,23 +101,28 @@ class Image23DPipeline(Pipeline):
self._device = torch.device('cuda') self._device = torch.device('cuda')
else: else:
self._device = torch.device('cpu') self._device = torch.device('cpu')
ckpt = config_path.replace("configuration.json", "syncdreamer-pretrain.ckpt") ckpt = config_path.replace('configuration.json',
self.model = load_model(config_path.replace("configuration.json", "syncdreamer.yaml"), ckpt).to(self._device) 'syncdreamer-pretrain.ckpt')
self.model = load_model(
config_path.replace('configuration.json', 'syncdreamer.yaml'),
ckpt).to(self._device)
# os.system("pip install -r {}".format(config_path.replace("configuration.json", "requirements.txt"))) # os.system("pip install -r {}".format(config_path.replace("configuration.json", "requirements.txt")))
# assert isinstance(self.model, SyncMultiviewDiffusion) # assert isinstance(self.model, SyncMultiviewDiffusion)
def preprocess(self, input: Input) -> Dict[str, Any]: def preprocess(self, input: Input) -> Dict[str, Any]:
result = rembg.remove(Image.open(input)) result = rembg.remove(Image.open(input))
print(type(result)) print(type(result))
img = np.array(result) img = np.array(result)
img[:,:,:3] = img[:,:,:3][:,:,::-1] img[:, :, :3] = img[:, :, :3][:, :, ::-1]
# img = cv2.imread(input) # img = cv2.imread(input)
data = prepare_inputs(img, elevation_input=10, crop_size=200, image_size=256) data = prepare_inputs(
img, elevation_input=10, crop_size=200, image_size=256)
for k,v in data.items():
for k, v in data.items():
data[k] = v.unsqueeze(0).cuda() data[k] = v.unsqueeze(0).cuda()
data[k] = torch.repeat_interleave(data[k], 1, dim=0) # only one sample data[k] = torch.repeat_interleave(
data[k], 1, dim=0) # only one sample
return data return data
@torch.no_grad() @torch.no_grad()
@@ -115,11 +130,11 @@ class Image23DPipeline(Pipeline):
x_sample = self.model.sample(input, 2.0, 8) x_sample = self.model.sample(input, 2.0, 8)
B, N, _, H, W = x_sample.shape B, N, _, H, W = x_sample.shape
x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5 x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255 x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
show_in_im2 = [Image.fromarray(x_sample[0,ni]) for ni in range(N)] show_in_im2 = [Image.fromarray(x_sample[0, ni]) for ni in range(N)]
return {'MViews':show_in_im2} return {'MViews': show_in_im2}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs return inputs
View File
@@ -3,6 +3,7 @@ import unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline from modelscope.pipelines.base import Pipeline
@@ -24,11 +25,11 @@ class ImageTo3DTest(unittest.TestCase):
def pipeline_inference(self, pipeline: Pipeline, input: str): def pipeline_inference(self, pipeline: Pipeline, input: str):
result = pipeline(input['input_path']) result = pipeline(input['input_path'])
np_content = [] np_content = []
for idx,img in enumerate(result['MViews']): for idx, img in enumerate(result['MViews']):
np_content.append(np.array(result['MViews'][idx])) np_content.append(np.array(result['MViews'][idx]))
np_content = np.concatenate(np_content, axis=1) np_content = np.concatenate(np_content, axis=1)
Image.fromarray(np_content).save("./concat.png") Image.fromarray(np_content).save('./concat.png')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self): def test_run_modelhub(self):
@@ -38,4 +39,4 @@ class ImageTo3DTest(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()