diff --git a/modelscope/models/cv/anydoor/cldm/__init__.py b/modelscope/models/cv/anydoor/cldm/__init__.py index 8b137891..e69de29b 100644 --- a/modelscope/models/cv/anydoor/cldm/__init__.py +++ b/modelscope/models/cv/anydoor/cldm/__init__.py @@ -1 +0,0 @@ - diff --git a/modelscope/models/cv/anydoor/ldm/__init__.py b/modelscope/models/cv/anydoor/ldm/__init__.py index 8b137891..e69de29b 100644 --- a/modelscope/models/cv/anydoor/ldm/__init__.py +++ b/modelscope/models/cv/anydoor/ldm/__init__.py @@ -1 +0,0 @@ - diff --git a/modelscope/models/cv/image_to_3d/__init__.py b/modelscope/models/cv/image_to_3d/__init__.py index b41515ef..44c42428 100644 --- a/modelscope/models/cv/image_to_3d/__init__.py +++ b/modelscope/models/cv/image_to_3d/__init__.py @@ -1,2 +1,2 @@ # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. -from . import ldm \ No newline at end of file +from . import ldm diff --git a/modelscope/models/cv/image_to_3d/ldm/base_utils.py b/modelscope/models/cv/image_to_3d/ldm/base_utils.py index 6f4b6843..3362fa18 100644 --- a/modelscope/models/cv/image_to_3d/ldm/base_utils.py +++ b/modelscope/models/cv/image_to_3d/ldm/base_utils.py @@ -1,6 +1,7 @@ import pickle -import numpy as np + import cv2 +import numpy as np from skimage.io import imread @@ -9,16 +10,18 @@ def save_pickle(data, pkl_path): with open(pkl_path, 'wb') as f: pickle.dump(data, f) + def read_pickle(pkl_path): with open(pkl_path, 'rb') as f: return pickle.load(f) + def draw_epipolar_line(F, img0, img1, pt0, color): - h1,w1=img1.shape[:2] + h1, w1 = img1.shape[:2] hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None] - l = F @ hpt - l = l[:, 0] - a, b, c = l[0], l[1], l[2] + ln = F @ hpt + ln = ln[:, 0] + a, b, c = ln[0], ln[1], ln[2] pt1 = np.asarray([0, -c / b]).astype(np.int32) pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32) @@ -26,8 +29,9 @@ def draw_epipolar_line(F, img0, img1, pt0, color): img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2) return img0, img1 -def draw_epipolar_lines(F, img0, img1,num=20): - img0,img1=img0.copy(),img1.copy() + +def draw_epipolar_lines(F, img0, img1, num=20): + img0, img1 = img0.copy(), img1.copy() h0, w0, _ = img0.shape h1, w1, _ = img1.shape @@ -42,117 +46,166 @@ def draw_epipolar_lines(F, img0, img1,num=20): return img0, img1 + def compute_F(K1, K2, Rt0, Rt1=None): if Rt1 is None: - R, t = Rt0[:,:3], Rt0[:,3:] + R, t = Rt0[:, :3], Rt0[:, 3:] else: - Rt = compute_dR_dt(Rt0,Rt1) - R, t = Rt[:,:3], Rt[:,3:] - A = K1 @ R.T @ t # [3,1] - C = np.asarray([[0,-A[2,0],A[1,0]], - [A[2,0],0,-A[0,0]], - [-A[1,0],A[0,0],0]]) + Rt = compute_dR_dt(Rt0, Rt1) + R, t = Rt[:, :3], Rt[:, 3:] + A = K1 @ R.T @ t # [3,1] + C = np.asarray([[0, -A[2, 0], A[1, 0]], [A[2, 0], 0, -A[0, 0]], + [-A[1, 0], A[0, 0], 0]]) F = (np.linalg.inv(K2)).T @ R @ K1.T @ C return F + def compute_dR_dt(Rt0, Rt1): - R0, t0 = Rt0[:,:3], Rt0[:,3:] - R1, t1 = Rt1[:,:3], Rt1[:,3:] + R0, t0 = Rt0[:, :3], Rt0[:, 3:] + R1, t1 = Rt1[:, :3], Rt1[:, 3:] dR = np.dot(R1, R0.T) dt = t1 - np.dot(dR, t0) return np.concatenate([dR, dt], -1) -def concat_images(img0,img1,vert=False): + +def concat_images(img0, img1, vert=False): if not vert: - h0,h1=img0.shape[0],img1.shape[0], - if h00) - if np.sum(mask0)>0: dpt[mask0]=1e-4 - mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0) - if np.sum(mask1)>0: dpt[mask1]=-1e-4 - pts2d = pts[:,:2]/dpt[:,None] + +def project_points(pts, RT, K): + pts = np.matmul(pts, RT[:, :3].transpose()) + RT[:, 3:].transpose() + pts = np.matmul(pts, K.transpose()) + 
dpt = pts[:, 2] + mask0 = (np.abs(dpt) < 1e-4) & (np.abs(dpt) > 0) + if np.sum(mask0) > 0: + dpt[mask0] = 1e-4 + mask1 = (np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0) + if np.sum(mask1) > 0: + dpt[mask1] = -1e-4 + pts2d = pts[:, :2] / dpt[:, None] return pts2d, dpt def draw_keypoints(img, kps, colors=None, radius=2): - out_img=img.copy() + out_img = img.copy() for pi, pt in enumerate(kps): pt = np.round(pt).astype(np.int32) if colors is not None: - color=[int(c) for c in colors[pi]] + color = [int(c) for c in colors[pi]] cv2.circle(out_img, tuple(pt), radius, color, -1) else: - cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1) + cv2.circle(out_img, tuple(pt), radius, (0, 255, 0), -1) return out_img -def output_points(fn,pts,colors=None): +def output_points(fn, pts, colors=None): with open(fn, 'w') as f: for pi, pt in enumerate(pts): f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ') if colors is not None: - f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}') + f.write( + f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}' + ) f.write('\n') + DEPTH_MAX, DEPTH_MIN = 2.4, 0.6 DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63 + + def read_depth_objaverse(depth_fn): depth = imread(depth_fn) - depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN + depth = depth.astype( + np.float32) / 65535 * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX) return depth, mask -def mask_depth_to_pts(mask,depth,K,rgb=None): - hs,ws=np.nonzero(mask) - depth=depth[hs,ws] - pts=np.asarray([ws,hs,depth],np.float32).transpose() - pts[:,:2]*=pts[:,2:] +def mask_depth_to_pts(mask, depth, K, rgb=None): + hs, ws = np.nonzero(mask) + depth = depth[hs, ws] + pts = np.asarray([ws, hs, depth], np.float32).transpose() + pts[:, :2] *= pts[:, 2:] if rgb is not None: - return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws] + return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs, ws] else: return np.dot(pts, np.linalg.inv(K).transpose()) + def transform_points_pose(pts, pose): R, t = pose[:, :3], pose[:, 3] - if len(pts.shape)==1: - return (R @ pts[:,None] + t[:,None])[:,0] - return pts @ R.T + t[None,:] + if len(pts.shape) == 1: + return (R @ pts[:, None] + t[:, None])[:, 0] + return pts @ R.T + t[None, :] -def pose_apply(pose,pts): + +def pose_apply(pose, pts): return transform_points_pose(pts, pose) + def downsample_gaussian_blur(img, ratio): sigma = (1 / ratio) / 3 # ksize=np.ceil(2*sigma) ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1)) ksize = ksize + 1 if ksize % 2 == 0 else ksize - img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101) - return img \ No newline at end of file + img = cv2.GaussianBlur( + img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101) + return img diff --git a/modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py b/modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py index 96b88d8a..6d5a538e 100644 --- a/modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py +++ b/modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py @@ -1,34 +1,36 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F from contextlib import contextmanager +import pytorch_lightning as pl +import torch +import torch.nn.functional as F from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import Encoder, Decoder -from 
modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution - +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import ( + Decoder, Encoder) +from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import \ + DiagonalGaussianDistribution from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False - ): + + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False): super().__init__() self.embed_dim = embed_dim self.n_embed = n_embed @@ -36,24 +38,31 @@ class VQModel(pl.LightningModule): self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.quantize = VectorQuantizer( + n_embed, + embed_dim, + beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor self.batch_resize_range = batch_resize_range if self.batch_resize_range is not None: - print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + print( + f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.' 
+ ) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) @@ -66,28 +75,30 @@ class VQModel(pl.LightningModule): self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: - print(f"{context}: Switched to EMA weights") + print(f'{context}: Switched to EMA weights') try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: - print(f"{context}: Restored training weights") + print(f'{context}: Restored training weights') def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] + sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + print( + f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' + ) if len(missing) > 0: - print(f"Missing Keys: {missing}") - print(f"Unexpected Keys: {unexpected}") + print(f'Missing Keys: {missing}') + print(f'Unexpected Keys: {unexpected}') def on_train_batch_end(self, *args, **kwargs): if self.use_ema: @@ -115,7 +126,7 @@ class VQModel(pl.LightningModule): return dec def forward(self, input, return_pred_indices=False): - quant, diff, (_,_,ind) = self.encode(input) + quant, diff, (_, _, ind) = self.encode(input) dec = self.decode(quant) if return_pred_indices: return dec, diff, ind @@ -125,7 +136,8 @@ class VQModel(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() if self.batch_resize_range is not None: lower_size = self.batch_resize_range[0] upper_size = self.batch_resize_range[1] @@ -133,9 +145,10 @@ class VQModel(pl.LightningModule): # do the first few batches with max size to avoid later oom new_resize = upper_size else: - new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + new_resize = np.random.choice( + np.arange(lower_size, upper_size + 16, 16)) if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") + x = F.interpolate(x, size=new_resize, mode='bicubic') x = x.detach() return x @@ -147,79 +160,123 @@ class VQModel(pl.LightningModule): if optimizer_idx == 0: # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train", - predicted_indices=ind) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train', + predicted_indices=ind) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True) return aeloss if optimizer_idx == 1: # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - 
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train') + self.log_dict( + log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True) return discloss def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + # with self.ema_scope(): + # log_dict_ema = self._validation_step( + # batch, batch_idx, suffix='_ema') return log_dict - def _validation_step(self, batch, batch_idx, suffix=""): + def _validation_step(self, batch, batch_idx, suffix=''): x = self.get_input(batch, self.image_key) xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split='val' + suffix, + predicted_indices=ind) - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log(f"val{suffix}/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log(f"val{suffix}/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split='val' + suffix, + predicted_indices=ind) + rec_loss = log_dict_ae[f'val{suffix}/rec_loss'] + self.log( + f'val{suffix}/rec_loss', + rec_loss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True) + self.log( + f'val{suffix}/aeloss', + aeloss, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True) if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f"val{suffix}/rec_loss"] + del log_dict_ae[f'val{suffix}/rec_loss'] self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr_d = self.learning_rate - lr_g = self.lr_g_factor*self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr_g, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr_d, betas=(0.5, 0.9)) + lr_g = self.lr_g_factor * self.learning_rate + print('lr_d', lr_d) + print('lr_g', lr_g) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, + betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") + print('Setting up LambdaLR scheduler...') scheduler = [ { - 'scheduler': LambdaLR(opt_ae, 
lr_lambda=scheduler.schedule), + 'scheduler': + LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1 }, { - 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'scheduler': + LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1 }, @@ -235,7 +292,7 @@ class VQModel(pl.LightningModule): x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: - log["inputs"] = x + log['inputs'] = x return log xrec, _ = self(x) if x.shape[1] > 3: @@ -243,25 +300,28 @@ class VQModel(pl.LightningModule): assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec + log['inputs'] = x + log['reconstructions'] = xrec if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) + log['reconstructions_ema'] = xrec_ema return log def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): super().__init__(embed_dim=embed_dim, *args, **kwargs) self.embed_dim = embed_dim @@ -283,43 +343,48 @@ class VQModelInterface(VQModel): class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ): + + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + ): super().__init__() self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] + sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print('Deleting key {} from state_dict.'.format(k)) del sd[k] self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") + print(f'Restored from {path}') def encode(self, x): h = 
self.encoder(x) @@ -345,7 +410,8 @@ class AutoencoderKL(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() return x def training_step(self, batch, batch_idx, optimizer_idx): @@ -354,44 +420,91 @@ class AutoencoderKL(pl.LightningModule): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train') + self.log( + 'aeloss', + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split='train') - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + self.log( + 'discloss', + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True) + self.log_dict( + log_dict_disc, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False) return discloss def validation_step(self, batch, batch_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split='val') - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split='val') - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log('val/rec_loss', log_dict_ae['val/rec_loss']) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, 
opt_disc], [] def get_last_layer(self): @@ -409,21 +522,23 @@ class AutoencoderKL(pl.LightningModule): assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - log["inputs"] = x + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + log['inputs'] = x return log def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff super().__init__() diff --git a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer.py b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer.py index 90e25c13..9783ee5b 100644 --- a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer.py +++ b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer.py @@ -1,26 +1,33 @@ from pathlib import Path +import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np from skimage.io import imsave from torch.optim.lr_scheduler import LambdaLR from tqdm import tqdm -from modelscope.models.cv.image_to_3d.ldm.base_utils import read_pickle, concat_images_list -from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume -from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding -from modelscope.models.cv.image_to_3d.ldm.modules.encoders.modules import FrozenCLIPImageEmbedder +from modelscope.models.cv.image_to_3d.ldm.base_utils import ( + concat_images_list, read_pickle) +from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_network import ( + FrustumTV3DNet, NoisyTargetViewEncoder, SpatialTime3DNet) +from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_utils import ( + create_target_volume, get_warp_coordinates) +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import ( + make_ddim_timesteps, timestep_embedding) +from modelscope.models.cv.image_to_3d.ldm.modules.encoders.modules import \ + FrozenCLIPImageEmbedder from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config + def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self + def disable_training_module(module: nn.Module): module = module.eval() module.train = disabled_train @@ -28,32 +35,39 @@ def disable_training_module(module: nn.Module): para.requires_grad = False return module + def repeat_to_batch(tensor, B, VN): t_shape = tensor.shape - ones = [1 for _ in range(len(t_shape)-1)] - tensor_new = 
tensor.view(B,1,*t_shape[1:]).repeat(1,VN,*ones).view(B*VN,*t_shape[1:]) + ones = [1 for _ in range(len(t_shape) - 1)] + tensor_new = tensor.view(B, 1, *t_shape[1:]).repeat(1, VN, *ones).view( + B * VN, *t_shape[1:]) return tensor_new + class UNetWrapper(nn.Module): - def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True): + + def __init__(self, + diff_model_config, + drop_conditions=False, + drop_scheme='default', + use_zero_123=True): super().__init__() self.diffusion_model = instantiate_from_config(diff_model_config) self.drop_conditions = drop_conditions - self.drop_scheme=drop_scheme + self.drop_scheme = drop_scheme self.use_zero_123 = use_zero_123 - def drop(self, cond, mask): shape = cond.shape B = shape[0] - cond = mask.view(B,*[1 for _ in range(len(shape)-1)]) * cond + cond = mask.view(B, *[1 for _ in range(len(shape) - 1)]) * cond return cond def get_trainable_parameters(self): return self.diffusion_model.get_trainable_parameters() def get_drop_scheme(self, B, device): - if self.drop_scheme=='default': + if self.drop_scheme == 'default': random = torch.rand(B, dtype=torch.float32, device=device) drop_clip = (random > 0.15) & (random <= 0.2) drop_volume = (random > 0.1) & (random <= 0.15) @@ -63,7 +77,13 @@ class UNetWrapper(nn.Module): raise NotImplementedError return drop_clip, drop_volume, drop_concat, drop_all - def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False): + def forward(self, + x, + t, + clip_embed, + volume_feats, + x_concat, + is_train=False): """ @param x: B,4,H,W @@ -76,7 +96,8 @@ class UNetWrapper(nn.Module): """ if self.drop_conditions and is_train: B = x.shape[0] - drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device) + drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme( + B, x.device) clip_mask = 1.0 - (drop_clip | drop_all).float() clip_embed = self.drop(clip_embed, clip_mask) @@ -100,7 +121,8 @@ class UNetWrapper(nn.Module): pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats) return pred - def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale): + def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, + x_concat, unconditional_scale): x_ = torch.cat([x] * 2, 0) t_ = torch.cat([t] * 2, 0) clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0) @@ -115,21 +137,34 @@ class UNetWrapper(nn.Module): first_stage_scale_factor = 0.18215 x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor x_ = torch.cat([x_, x_concat_], 1) - s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2) + s, s_uc = self.diffusion_model( + x_, t_, clip_embed_, source_dict=v_).chunk(2) s = s_uc + unconditional_scale * (s - s_uc) return s class SpatialVolumeNet(nn.Module): - def __init__(self, time_dim, view_dim, view_num, - input_image_size=256, frustum_volume_depth=48, - spatial_volume_size=32, spatial_volume_length=0.5, - frustum_volume_length=0.86603 # sqrt(3)/2 - ): + + def __init__( + self, + time_dim, + view_dim, + view_num, + input_image_size=256, + frustum_volume_depth=48, + spatial_volume_size=32, + spatial_volume_length=0.5, + frustum_volume_length=0.86603 # sqrt(3)/2 + ): super().__init__() - self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16) - self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512)) - self.frustum_volume_feats = FrustumTV3DNet(64, 
time_dim, view_dim, dims=(64, 128, 256, 512)) + self.target_encoder = NoisyTargetViewEncoder( + time_dim, view_dim, output_dim=16) + self.spatial_volume_feats = SpatialTime3DNet( + input_dim=16 * view_num, + time_dim=time_dim, + dims=(64, 128, 256, 512)) + self.frustum_volume_feats = FrustumTV3DNet( + 64, time_dim, view_dim, dims=(64, 128, 256, 512)) self.frustum_volume_length = frustum_volume_length self.input_image_size = input_image_size @@ -140,9 +175,11 @@ class SpatialVolumeNet(nn.Module): self.frustum_volume_depth = frustum_volume_depth self.time_dim = time_dim self.view_dim = view_dim - self.default_origin_depth = 1.5 # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin + # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin + self.default_origin_depth = 1.5 - def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks): + def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, + target_Ks): """ @param x: B,N,4,H,W @param t_embed: B,t_dim @@ -155,13 +192,23 @@ class SpatialVolumeNet(nn.Module): V = self.spatial_volume_size device = x.device - spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device) - spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1) - spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)] - spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1) + spatial_volume_verts = torch.linspace( + -self.spatial_volume_length, + self.spatial_volume_length, + V, + dtype=torch.float32, + device=device) + spatial_volume_verts = torch.stack( + torch.meshgrid(spatial_volume_verts, spatial_volume_verts, + spatial_volume_verts), -1) + spatial_volume_verts = spatial_volume_verts.reshape(1, V**3, + 3)[:, :, (2, 1, 0)] + spatial_volume_verts = spatial_volume_verts.view( + 1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1) # encode source features - t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim) + t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view( + B, N, self.time_dim) # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim) v_embed_ = v_embed target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1) @@ -173,22 +220,33 @@ class SpatialVolumeNet(nn.Module): for ni in range(0, N): pose_source_ = target_poses[:, ni] K_source_ = target_Ks[:, ni] - x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, ni]) + x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, + ni]) C = x_.shape[1] - coords_source = get_warp_coordinates(spatial_volume_verts, x_.shape[-1], self.input_image_size, K_source_, pose_source_).view(B, V, V * V, 2) - unproj_feats_ = F.grid_sample(x_, coords_source, mode='bilinear', padding_mode='zeros', align_corners=True) + coords_source = get_warp_coordinates( + spatial_volume_verts, x_.shape[-1], self.input_image_size, + K_source_, pose_source_).view(B, V, V * V, 2) + unproj_feats_ = F.grid_sample( + x_, + coords_source, + mode='bilinear', + padding_mode='zeros', + align_corners=True) unproj_feats_ = unproj_feats_.view(B, C, V, V, V) spatial_volume_feats.append(unproj_feats_) - spatial_volume_feats = torch.stack(spatial_volume_feats, 1) # B,N,C,V,V,V + spatial_volume_feats = torch.stack(spatial_volume_feats, + 1) # 
B,N,C,V,V,V N = spatial_volume_feats.shape[1] - spatial_volume_feats = spatial_volume_feats.view(B, N*C, V, V, V) + spatial_volume_feats = spatial_volume_feats.view(B, N * C, V, V, V) - spatial_volume_feats = self.spatial_volume_feats(spatial_volume_feats, t_embed) # b,64,32,32,32 + spatial_volume_feats = self.spatial_volume_feats( + spatial_volume_feats, t_embed) # b,64,32,32,32 return spatial_volume_feats - def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices): + def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, + poses, Ks, target_indices): """ @param spatial_volume: B,C,V,V,V @param t_embed: B,t_dim @@ -203,34 +261,73 @@ class SpatialVolumeNet(nn.Module): D = self.frustum_volume_depth V = self.spatial_volume_size - near = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length - far = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length + near = torch.ones( + B * TN, + 1, + H, + W, + dtype=spatial_volume.dtype, + device=spatial_volume.device + ) * self.default_origin_depth - self.frustum_volume_length + far = torch.ones( + B * TN, + 1, + H, + W, + dtype=spatial_volume.dtype, + device=spatial_volume.device + ) * self.default_origin_depth + self.frustum_volume_length - target_indices = target_indices.view(B*TN) # B*TN - poses_ = poses[target_indices] # B*TN,3,4 - Ks_ = Ks[target_indices] # B*TN,3,4 - volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, near, far) # B*TN,3 or 1,D,H,W + target_indices = target_indices.view(B * TN) # B*TN + poses_ = poses[target_indices] # B*TN,3,4 + Ks_ = Ks[target_indices] # B*TN,3,4 + volume_xyz, volume_depth = create_target_volume( + D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, + near, far) # B*TN,3 or 1,D,H,W - volume_xyz_ = volume_xyz / self.spatial_volume_length # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length] + # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length] + volume_xyz_ = volume_xyz / self.spatial_volume_length volume_xyz_ = volume_xyz_.permute(0, 2, 3, 4, 1) # B*TN,D,H,W,3 - spatial_volume_ = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V) - volume_feats = F.grid_sample(spatial_volume_, volume_xyz_, mode='bilinear', padding_mode='zeros', align_corners=True) # B*TN,C,D,H,W + spatial_volume_ = spatial_volume.unsqueeze(1).repeat( + 1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V) + volume_feats = F.grid_sample( + spatial_volume_, + volume_xyz_, + mode='bilinear', + padding_mode='zeros', + align_corners=True) # B*TN,C,D,H,W - v_embed_ = v_embed[torch.arange(B)[:,None], target_indices.view(B,TN)].view(B*TN, -1) # B*TN - t_embed_ = t_embed.unsqueeze(1).repeat(1,TN,1).view(B*TN,-1) - volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, v_embed_) + v_embed_ = v_embed[torch.arange(B)[:, None], + target_indices.view(B, TN)].view(B * TN, -1) # B*TN + t_embed_ = t_embed.unsqueeze(1).repeat(1, TN, 1).view(B * TN, -1) + volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, + v_embed_) return volume_feats_dict, volume_depth + + """ SyncDreamer is a SoTA Novel View Synthesis model which can generate 16 consistent views seamlessly. 
Please refer to: https://arxiv.org/abs/2309.03453 for more technique details. """ + + class SyncMultiviewDiffusion(pl.LightningModule): - def __init__(self, unet_config, scheduler_config, - finetune_unet=False, finetune_projection=True, - view_num=16, image_size=256, - cfg_scale=3.0, output_num=8, batch_view_num=4, - drop_conditions=False, drop_scheme='default', - clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"): + + def __init__( + self, + unet_config, + scheduler_config, + finetune_unet=False, + finetune_projection=True, + view_num=16, + image_size=256, + cfg_scale=3.0, + output_num=8, + batch_view_num=4, + drop_conditions=False, + drop_scheme='default', + clip_image_encoder_path='/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt' + ): super().__init__() self.finetune_unet = finetune_unet @@ -253,12 +350,18 @@ class SyncMultiviewDiffusion(pl.LightningModule): self._init_clip_image_encoder() self._init_clip_projection() - self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num) - self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme) + self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, + self.viewpoint_dim, + self.view_num) + self.model = UNetWrapper( + unet_config, + drop_conditions=drop_conditions, + drop_scheme=drop_scheme) self.scheduler_config = scheduler_config - latent_size = image_size//8 - self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size) + latent_size = image_size // 8 + self.ddim = SyncDDIMSampler( + self, 200, 'uniform', 1.0, latent_size=latent_size) def _init_clip_projection(self): self.cc_projection = nn.Linear(772, 768) @@ -270,17 +373,21 @@ class SyncMultiviewDiffusion(pl.LightningModule): disable_training_module(self.cc_projection) def _init_multiview(self): - K, azs, _, _, poses = read_pickle(self.clip_image_encoder_path.replace("ViT-L-14.pt",f'camera-{self.view_num}.pkl')) + K, azs, _, _, poses = read_pickle( + self.clip_image_encoder_path.replace( + 'ViT-L-14.pt', f'camera-{self.view_num}.pkl')) default_image_size = 256 - ratio = self.image_size/default_image_size - K = np.diag([ratio,ratio,1]) @ K - K = torch.from_numpy(K.astype(np.float32)) # [3,3] - K = K.unsqueeze(0).repeat(self.view_num,1,1) # N,3,3 + ratio = self.image_size / default_image_size + K = np.diag([ratio, ratio, 1]) @ K + K = torch.from_numpy(K.astype(np.float32)) # [3,3] + K = K.unsqueeze(0).repeat(self.view_num, 1, 1) # N,3,3 poses = torch.from_numpy(poses.astype(np.float32)) # N,3,4 self.register_buffer('poses', poses) self.register_buffer('Ks', K) - azs = (azs + np.pi) % (np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0 - self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32))) + azs = (azs + np.pi) % ( + np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0 + self.register_buffer('azimuth', + torch.from_numpy(azs.astype(np.float32))) def get_viewpoint_embedding(self, batch_size, elevation_ref): """ @@ -288,72 +395,90 @@ class SyncMultiviewDiffusion(pl.LightningModule): @param elevation_ref: B @return: """ - azimuth_input = self.azimuth[0].unsqueeze(0) # 1 - azimuth_target = self.azimuth # N - elevation_input = -elevation_ref # note that zero123 use a negative elevation here!!! + azimuth_input = self.azimuth[0].unsqueeze(0) # 1 + azimuth_target = self.azimuth # N + elevation_input = -elevation_ref # note that zero123 use a negative elevation here!!! 
elevation_target = -np.deg2rad(30) - d_e = elevation_target - elevation_input # B + d_e = elevation_target - elevation_input # B N = self.azimuth.shape[0] B = batch_size d_e = d_e.unsqueeze(1).repeat(1, N) - d_a = azimuth_target - azimuth_input # N + d_a = azimuth_target - azimuth_input # N d_a = d_a.unsqueeze(0).repeat(B, 1) d_z = torch.zeros_like(d_a) - embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4 + embedding = torch.stack( + [d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4 return embedding def _init_first_stage(self): - first_stage_config={ - "target": "modelscope.models.cv.image_to_3d.ldm.models.autoencoder.AutoencoderKL", - "params": { - "embed_dim": 4, - "monitor": "val/rec_loss", - "ddconfig":{ - "double_z": True, - "z_channels": 4, - "resolution": self.image_size, - "in_channels": 3, - "out_ch": 3, - "ch": 128, - "ch_mult": [1,2,4,4], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0.0 + first_stage_config = { + 'target': + 'modelscope.models.cv.image_to_3d.ldm.models.autoencoder.AutoencoderKL', + 'params': { + 'embed_dim': 4, + 'monitor': 'val/rec_loss', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': self.image_size, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + }, + 'lossconfig': { + 'target': 'torch.nn.Identity' }, - "lossconfig": {"target": "torch.nn.Identity"}, } } self.first_stage_scale_factor = 0.18215 self.first_stage_model = instantiate_from_config(first_stage_config) - self.first_stage_model = disable_training_module(self.first_stage_model) + self.first_stage_model = disable_training_module( + self.first_stage_model) def _init_clip_image_encoder(self): - self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path) - self.clip_image_encoder = disable_training_module(self.clip_image_encoder) + self.clip_image_encoder = FrozenCLIPImageEmbedder( + model=self.clip_image_encoder_path) + self.clip_image_encoder = disable_training_module( + self.clip_image_encoder) def _init_schedule(self): self.num_timesteps = 1000 linear_start = 0.00085 linear_end = 0.0120 num_timesteps = 1000 - betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2 # T + betas = torch.linspace( + linear_start**0.5, + linear_end**0.5, + num_timesteps, + dtype=torch.float32)**2 # T assert betas.shape[0] == self.num_timesteps # all in float64 first alphas = 1. - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) # T - alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # T - posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20)) - posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10) + alphas_cumprod = torch.cumprod(alphas, dim=0) # T + alphas_cumprod_prev = torch.cat( + [torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / ( + 1. 
- alphas_cumprod) # T + posterior_log_variance_clipped = torch.log( + torch.clamp(posterior_variance, min=1e-20)) + posterior_log_variance_clipped = torch.clamp( + posterior_log_variance_clipped, min=-10) - self.register_buffer("betas", betas.float()) - self.register_buffer("alphas", alphas.float()) - self.register_buffer("alphas_cumprod", alphas_cumprod.float()) - self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float()) - self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float()) - self.register_buffer("posterior_variance", posterior_variance.float()) - self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float()) + self.register_buffer('betas', betas.float()) + self.register_buffer('alphas', alphas.float()) + self.register_buffer('alphas_cumprod', alphas_cumprod.float()) + self.register_buffer('sqrt_alphas_cumprod', + torch.sqrt(alphas_cumprod).float()) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1 - alphas_cumprod).float()) + self.register_buffer('posterior_variance', posterior_variance.float()) + self.register_buffer('posterior_log_variance_clipped', + posterior_log_variance_clipped.float()) def _init_time_step_embedding(self): self.time_embed_dim = 256 @@ -367,9 +492,11 @@ class SyncMultiviewDiffusion(pl.LightningModule): with torch.no_grad(): posterior = self.first_stage_model.encode(x) # b,4,h//8,w//8 if sample: - return posterior.sample().detach() * self.first_stage_scale_factor + return posterior.sample().detach( + ) * self.first_stage_scale_factor else: - return posterior.mode().detach() * self.first_stage_scale_factor + return posterior.mode().detach( + ) * self.first_stage_scale_factor def decode_first_stage(self, z): with torch.no_grad(): @@ -379,27 +506,37 @@ class SyncMultiviewDiffusion(pl.LightningModule): def prepare(self, batch): # encode target if 'target_image' in batch: - image_target = batch['target_image'].permute(0, 1, 4, 2, 3) # b,n,3,h,w + image_target = batch['target_image'].permute(0, 1, 4, 2, + 3) # b,n,3,h,w N = image_target.shape[1] - x = [self.encode_first_stage(image_target[:,ni], True) for ni in range(N)] - x = torch.stack(x, 1) # b,n,4,h//8,w//8 + x = [ + self.encode_first_stage(image_target[:, ni], True) + for ni in range(N) + ] + x = torch.stack(x, 1) # b,n,4,h//8,w//8 else: x = None image_input = batch['input_image'].permute(0, 3, 1, 2) - elevation_input = batch['input_elevation'][:, 0] # b + elevation_input = batch['input_elevation'][:, 0] # b x_input = self.encode_first_stage(image_input) - input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input} + input_info = { + 'image': image_input, + 'elevation': elevation_input, + 'x': x_input + } with torch.no_grad(): clip_embed = self.clip_image_encoder.encode(image_input) return x, clip_embed, input_info def embed_time(self, t): - t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False) # B,TED - t_embed = self.time_embed(t_embed) # B,TED + t_embed = timestep_embedding( + t, self.time_embed_dim, repeat_only=False) # B,TED + t_embed = self.time_embed(t_embed) # B,TED return t_embed - def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index): + def get_target_view_feats(self, x_input, spatial_volume, clip_embed, + t_embed, v_embed, target_index): """ @param x_input: B,4,H,W @param spatial_volume: B,C,V,V,V @@ -411,48 +548,91 @@ class SyncMultiviewDiffusion(pl.LightningModule): tensors of size B*TN,* """ B, _, H, W = 
x_input.shape - frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index) + frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume( + spatial_volume, t_embed, v_embed, self.poses, self.Ks, + target_index) # clip TN = target_index.shape[1] - v_embed_ = v_embed[torch.arange(B)[:,None], target_index].view(B*TN, self.viewpoint_dim) # B*TN,v_dim - clip_embed_ = clip_embed.unsqueeze(1).repeat(1,TN,1,1).view(B*TN,1,768) - clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768 + v_embed_ = v_embed[torch.arange(B)[:, None], + target_index].view(B * TN, + self.viewpoint_dim) # B*TN,v_dim + clip_embed_ = clip_embed.unsqueeze(1).repeat(1, TN, 1, + 1).view(B * TN, 1, 768) + clip_embed_ = self.cc_projection( + torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768 - x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W) + x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, + 1).view(B * TN, 4, H, W) x_concat = x_input_ return clip_embed_, frustum_volume_feats, x_concat def training_step(self, batch): B = batch['image'].shape[0] - time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long() + time_steps = torch.randint( + 0, self.num_timesteps, (B, ), device=self.device).long() x, clip_embed, input_info = self.prepare(batch) x_noisy, noise = self.add_noise(x, time_steps) # B,N,4,H,W N = self.view_num - target_index = torch.randint(0, N, (B, 1), device=self.device).long() # B, 1 - v_embed = self.get_viewpoint_embedding(B, input_info['elevation']) # N,v_dim + target_index = torch.randint( + 0, N, (B, 1), device=self.device).long() # B, 1 + v_embed = self.get_viewpoint_embedding( + B, input_info['elevation']) # N,v_dim t_embed = self.embed_time(time_steps) - spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks) + spatial_volume = self.spatial_volume.construct_spatial_volume( + x_noisy, t_embed, v_embed, self.poses, self.Ks) - clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index) + clip_embed, volume_feats, x_concat = self.get_target_view_feats( + input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, + target_index) - x_noisy_ = x_noisy[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W - noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True) # B,4,H,W + x_noisy_ = x_noisy[torch.arange(B)[:, None], + target_index][:, 0] # B,4,H,W + noise_predict = self.model( + x_noisy_, + time_steps, + clip_embed, + volume_feats, + x_concat, + is_train=True) # B,4,H,W - noise_target = noise[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W + noise_target = noise[torch.arange(B)[:, None], + target_index][:, 0] # B,4,H,W # loss simple for diffusion - loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none') + loss_simple = torch.nn.functional.mse_loss( + noise_target, noise_predict, reduction='none') loss = loss_simple.mean() - self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True) + self.log( + 'sim', + loss_simple.mean(), + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + rank_zero_only=True) # log others lr = self.optimizers().param_groups[0]['lr'] - self.log('lr', lr, 
prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) - self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) + self.log( + 'lr', + lr, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + rank_zero_only=True) + self.log( + 'step', + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + rank_zero_only=True) return loss def add_noise(self, x_start, t): @@ -462,65 +642,100 @@ class SyncMultiviewDiffusion(pl.LightningModule): @return: """ B = x_start.shape[0] - noise = torch.randn_like(x_start) # B,* + noise = torch.randn_like(x_start) # B,* - sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B, - sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t] # B - sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) - sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) + sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B, + sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[ + t] # B + sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view( + B, *[1 for _ in range(len(x_start.shape) - 1)]) + sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view( + B, *[1 for _ in range(len(x_start.shape) - 1)]) x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise return x_noisy, noise - def sample(self, batch, cfg_scale, batch_view_num, use_ddim=True, - return_inter_results=False, inter_interval=50, inter_view_interval=2): + def sample(self, + batch, + cfg_scale, + batch_view_num, + use_ddim=True, + return_inter_results=False, + inter_interval=50, + inter_view_interval=2): _, clip_embed, input_info = self.prepare(batch) if use_ddim: - x_sample, inter = self.ddim.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num) + x_sample, inter = self.ddim.sample( + input_info, + clip_embed, + unconditional_scale=cfg_scale, + log_every_t=inter_interval, + batch_view_num=batch_view_num) else: raise NotImplementedError N = x_sample.shape[1] - x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1) + x_sample = torch.stack( + [self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1) if return_inter_results: torch.cuda.synchronize() torch.cuda.empty_cache() - inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W - B,N,T,C,H,W = inter.shape + inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W + B, N, T, C, H, W = inter.shape inter_results = [] for ni in tqdm(range(0, N, inter_view_interval)): inter_results_ = [] for ti in range(T): - inter_results_.append(self.decode_first_stage(inter[:, ni, ti])) - inter_results.append(torch.stack(inter_results_, 1)) # B,T,3,H,W - inter_results = torch.stack(inter_results,1) # B,N,T,3,H,W + inter_results_.append( + self.decode_first_stage(inter[:, ni, ti])) + inter_results.append(torch.stack(inter_results_, + 1)) # B,T,3,H,W + inter_results = torch.stack(inter_results, 1) # B,N,T,3,H,W return x_sample, inter_results else: return x_sample - def log_image(self, x_sample, batch, step, output_dir, only_first_row=False): - process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8) + def log_image(self, + x_sample, + batch, + step, + output_dir, + only_first_row=False): + + def process(x): + return ((torch.clip(x, min=-1, 
max=1).cpu().numpy() * 0.5 + 0.5) + * 255).astype(np.uint8) + B = x_sample.shape[0] N = x_sample.shape[1] image_cond = [] for bi in range(B): - img_pr_ = concat_images_list(process(batch['ref_image'][bi]),*[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)]) - img_gt_ = concat_images_list(process(batch['ref_image'][bi]),*[process(batch['image'][bi, ni]) for ni in range(N)]) - if not only_first_row or bi==0: - image_cond.append(concat_images_list(img_gt_, img_pr_, vert=True)) + img_pr_ = concat_images_list( + process(batch['ref_image'][bi]), *[ + process(x_sample[bi, ni].permute(1, 2, 0)) + for ni in range(N) + ]) + img_gt_ = concat_images_list( + process(batch['ref_image'][bi]), + *[process(batch['image'][bi, ni]) for ni in range(N)]) + if not only_first_row or bi == 0: + image_cond.append( + concat_images_list(img_gt_, img_pr_, vert=True)) else: image_cond.append(img_pr_) - output_dir = Path(output_dir) - imsave(str(output_dir/f'{step}.jpg'), concat_images_list(*image_cond, vert=True)) + imsave( + str(output_dir / f'{step}.jpg'), + concat_images_list(*image_cond, vert=True)) @torch.no_grad() def validation_step(self, batch, batch_idx): - if batch_idx==0 and self.global_rank==0: + if batch_idx == 0 and self.global_rank == 0: self.eval() step = self.global_step batch_ = {} - for k, v in batch.items(): batch_[k] = v[:self.output_num] + for k, v in batch.items(): + batch_[k] = v[:self.output_num] x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num) output_dir = Path(self.image_dir) / 'images' / 'val' output_dir.mkdir(exist_ok=True, parents=True) @@ -531,24 +746,49 @@ class SyncMultiviewDiffusion(pl.LightningModule): print(f'setting learning rate to {lr:.4f} ...') paras = [] if self.finetune_projection: - paras.append({"params": self.cc_projection.parameters(), "lr": lr},) + paras.append({ + 'params': self.cc_projection.parameters(), + 'lr': lr + }, ) if self.finetune_unet: - paras.append({"params": self.model.parameters(), "lr": lr},) + paras.append({'params': self.model.parameters(), 'lr': lr}, ) else: - paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},) + paras.append( + { + 'params': self.model.get_trainable_parameters(), + 'lr': lr + }, ) - paras.append({"params": self.time_embed.parameters(), "lr": lr*10.0},) - paras.append({"params": self.spatial_volume.parameters(), "lr": lr*10.0},) + paras.append({ + 'params': self.time_embed.parameters(), + 'lr': lr * 10.0 + }, ) + paras.append( + { + 'params': self.spatial_volume.parameters(), + 'lr': lr * 10.0 + }, ) opt = torch.optim.AdamW(paras, lr=lr) scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + print('Setting up LambdaLR scheduler...') + scheduler = [{ + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] return [opt], scheduler + class SyncDDIMSampler: - def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., latent_size=32): + + def __init__(self, + model: SyncMultiviewDiffusion, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0., + latent_size=32): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps @@ -556,25 +796,43 @@ class SyncDDIMSampler: self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta) self.eta = ddim_eta - def _make_schedule(self, ddim_num_steps, 
ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) # DT - ddim_timesteps_ = torch.from_numpy(self.ddim_timesteps.astype(np.int64)) # DT + def _make_schedule(self, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0., + verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) # DT + ddim_timesteps_ = torch.from_numpy( + self.ddim_timesteps.astype(np.int64)) # DT - alphas_cumprod = self.model.alphas_cumprod # T - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT - self.ddim_alphas_prev = torch.cat([alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], 0) # DT - self.ddim_sigmas = ddim_eta * torch.sqrt((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * (1 - self.ddim_alphas / self.ddim_alphas_prev)) + alphas_cumprod = self.model.alphas_cumprod # T + assert alphas_cumprod.shape[ + 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT + self.ddim_alphas_prev = torch.cat( + [alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], + 0) # DT + self.ddim_sigmas = ddim_eta * torch.sqrt( # noqa + (1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * # noqa + (1 - self.ddim_alphas / self.ddim_alphas_prev)) # noqa - self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT + self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT self.ddim_sigmas = self.ddim_sigmas.float() self.ddim_alphas = self.ddim_alphas.float() self.ddim_alphas_prev = self.ddim_alphas_prev.float() - self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float() - + self.ddim_sqrt_one_minus_alphas = torch.sqrt( + 1. - self.ddim_alphas).float() @torch.no_grad() - def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False): + def denoise_apply_impl(self, + x_target_noisy, + index, + noise_pred, + is_step0=False): """ @param x_target_noisy: B,N,4,H,W @param index: index @@ -583,16 +841,21 @@ class SyncDDIMSampler: @return: """ device = x_target_noisy.device - B,N,_,H,W = x_target_noisy.shape + B, N, _, H, W = x_target_noisy.shape # apply noise - a_t = self.ddim_alphas[index].to(device).float().view(1,1,1,1,1) - a_prev = self.ddim_alphas_prev[index].to(device).float().view(1,1,1,1,1) - sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to(device).float().view(1,1,1,1,1) - sigma_t = self.ddim_sigmas[index].to(device).float().view(1,1,1,1,1) + a_t = self.ddim_alphas[index].to(device).float().view(1, 1, 1, 1, 1) + a_prev = self.ddim_alphas_prev[index].to(device).float().view( + 1, 1, 1, 1, 1) + sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to( + device).float().view(1, 1, 1, 1, 1) + sigma_t = self.ddim_sigmas[index].to(device).float().view( + 1, 1, 1, 1, 1) - pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() - dir_xt = torch.clamp(1. - a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred + pred_x0 = (x_target_noisy + - sqrt_one_minus_at * noise_pred) / a_t.sqrt() + dir_xt = torch.clamp( + 1. 
- a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred x_prev = a_prev.sqrt() * pred_x0 + dir_xt if not is_step0: noise = sigma_t * torch.randn_like(x_target_noisy) @@ -600,7 +863,15 @@ class SyncDDIMSampler: return x_prev @torch.no_grad() - def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False): + def denoise_apply(self, + x_target_noisy, + input_info, + clip_embed, + time_steps, + index, + unconditional_scale, + batch_view_num=1, + is_step0=False): """ @param x_target_noisy: B,N,4,H,W @param input_info: @@ -616,32 +887,50 @@ class SyncDDIMSampler: B, N, C, H, W = x_target_noisy.shape # construct source data - v_embed = self.model.get_viewpoint_embedding(B, elevation_input) # B,N,v_dim + v_embed = self.model.get_viewpoint_embedding( + B, elevation_input) # B,N,v_dim t_embed = self.model.embed_time(time_steps) # B,t_dim - spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks) + spatial_volume = self.model.spatial_volume.construct_spatial_volume( + x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks) e_t = [] - target_indices = torch.arange(N) # N + target_indices = torch.arange(N) # N for ni in range(0, N, batch_view_num): x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num] VN = x_target_noisy_.shape[1] - x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W) + x_target_noisy_ = x_target_noisy_.reshape(B * VN, C, H, W) time_steps_ = repeat_to_batch(time_steps, B, VN) - target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1) - clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_) - if unconditional_scale!=1.0: - noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale) + target_indices_ = target_indices[ni:ni + batch_view_num].unsqueeze( + 0).repeat(B, 1) + clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats( + x_input, spatial_volume, clip_embed, t_embed, v_embed, + target_indices_) + if unconditional_scale != 1.0: + noise = self.model.model.predict_with_unconditional_scale( + x_target_noisy_, time_steps_, clip_embed_, volume_feats_, + x_concat_, unconditional_scale) else: - noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False) - e_t.append(noise.view(B,VN,4,H,W)) + noise = self.model.model( + x_target_noisy_, + time_steps_, + clip_embed_, + volume_feats_, + x_concat_, + is_train=False) + e_t.append(noise.view(B, VN, 4, H, W)) e_t = torch.cat(e_t, 1) x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0) return x_prev @torch.no_grad() - def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1): + def sample(self, + input_info, + clip_embed, + unconditional_scale=1.0, + log_every_t=50, + batch_view_num=1): """ @param input_info: x, elevation @param clip_embed: B,M,768 @@ -650,7 +939,7 @@ class SyncDDIMSampler: @param batch_view_num: @return: """ - print(f"unconditional scale {unconditional_scale:.1f}") + print(f'unconditional scale {unconditional_scale:.1f}') C, H, W = 4, self.latent_size, self.latent_size B = clip_embed.shape[0] N = self.model.view_num @@ -664,10 +953,21 @@ class SyncDDIMSampler: iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) for i, step in 
enumerate(iterator): - index = total_steps - i - 1 # index in ddim state - time_steps = torch.full((B,), step, device=device, dtype=torch.long) - x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0) + index = total_steps - i - 1 # index in ddim state + time_steps = torch.full((B, ), + step, + device=device, + dtype=torch.long) + x_target_noisy = self.denoise_apply( + x_target_noisy, + input_info, + clip_embed, + time_steps, + index, + unconditional_scale, + batch_view_num=batch_view_num, + is_step0=index == 0) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(x_target_noisy) - return x_target_noisy, intermediates \ No newline at end of file + return x_target_noisy, intermediates diff --git a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_attention.py b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_attention.py index 866f8eb7..f1ad8b66 100644 --- a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_attention.py +++ b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_attention.py @@ -1,17 +1,27 @@ import torch import torch.nn as nn -from modelscope.models.cv.image_to_3d.ldm.modules.attention import default, zero_module, checkpoint -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import UNetModel -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding +from modelscope.models.cv.image_to_3d.ldm.modules.attention import ( # no qa + checkpoint, default, zero_module) +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import \ + UNetModel +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \ + timestep_embedding + class DepthAttention(nn.Module): - def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True): + + def __init__(self, + query_dim, + context_dim, + heads, + dim_head, + output_bias=True): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.dim_head = dim_head @@ -34,21 +44,27 @@ class DepthAttention(nn.Module): b, _, h, w = x.shape b, _, d, h, w = context.shape - q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w - k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w - v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w + q = self.to_q(x).reshape(b, hn, hd, h, w) # b,t,h,w + k = self.to_k(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w + v = self.to_v(context).reshape(b, hn, hd, d, h, w) # b,t,d,h,w - sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w + sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w attn = sim.softmax(dim=2) # b,hn,hd,d,h,w * b,hn,1,d,h,w - out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w - out = out.reshape(b,hn*hd,h,w) + out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w + out = out.reshape(b, hn * hd, h, w) return self.to_out(out) class DepthTransformer(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True): + + def __init__(self, + dim, + n_heads, + d_head, + context_dim=None, + checkpoint=True): super().__init__() inner_dim = n_heads * d_head self.proj_in = nn.Sequential( @@ -57,11 +73,18 @@ class DepthTransformer(nn.Module): nn.SiLU(True), ) self.proj_context = nn.Sequential( - 
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias + nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias nn.GroupNorm(8, context_dim), - nn.ReLU(True), # only relu, because we want input is 0, output is 0 + nn.ReLU( + True), # only relu, because we want input is 0, output is 0 ) - self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # is a self-attention if not self.disable_self_attn + self.depth_attn = DepthAttention( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + output_bias=False + ) # is a self-attention if not self.disable_self_attn self.proj_out = nn.Sequential( nn.GroupNorm(8, inner_dim), nn.ReLU(True), @@ -73,7 +96,8 @@ class DepthTransformer(nn.Module): self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), + self.checkpoint) def _forward(self, x, context): x_in = x @@ -85,38 +109,65 @@ class DepthTransformer(nn.Module): class DepthWiseAttention(UNetModel): - def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs): + + def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs): super().__init__(*args, **kwargs) # num_heads = 4 model_channels = kwargs['model_channels'] channel_mult = kwargs['channel_mult'] - d0,d1,d2,d3 = volume_dims + d0, d1, d2, d3 = volume_dims # 4 - ch = model_channels*channel_mult[2] - self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3) + ch = model_channels * channel_mult[2] + self.middle_conditions = DepthTransformer( + ch, 4, d3 // 2, context_dim=d3) - self.output_conditions=nn.ModuleList() - self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8} + self.output_conditions = nn.ModuleList() + self.output_b2c = { + 3: 0, + 4: 1, + 5: 2, + 6: 3, + 7: 4, + 8: 5, + 9: 6, + 10: 7, + 11: 8 + } # 8 - ch = model_channels*channel_mult[2] - self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0 - self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1 + ch = model_channels * channel_mult[2] + self.output_conditions.append( + DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0 + self.output_conditions.append( + DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1 # 16 - self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2 - ch = model_channels*channel_mult[1] - self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3 - self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4 + self.output_conditions.append( + DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2 + ch = model_channels * channel_mult[1] + self.output_conditions.append( + DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3 + self.output_conditions.append( + DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4 # 32 - self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5 - ch = model_channels*channel_mult[0] - self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6 - self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7 - self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8 + self.output_conditions.append( + DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5 + ch = model_channels * channel_mult[0] + 
self.output_conditions.append( + DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6 + self.output_conditions.append( + DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7 + self.output_conditions.append( + DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8 - def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs): + def forward(self, + x, + timesteps=None, + context=None, + source_dict=None, + **kwargs): hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) h = x.type(self.dtype) @@ -138,5 +189,6 @@ class DepthWiseAttention(UNetModel): return self.out(h) def get_trainable_parameters(self): - paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()] + paras = [para for para in self.middle_conditions.parameters() + ] + [para for para in self.output_conditions.parameters()] return paras diff --git a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_network.py b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_network.py index c03b3ddf..9b3d6616 100644 --- a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_network.py +++ b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_network.py @@ -1,10 +1,15 @@ import torch import torch.nn as nn + class Image2DResBlockWithTV(nn.Module): + def __init__(self, dim, tdim, vdim): super().__init__() - norm = lambda c: nn.GroupNorm(8, c) + + def norm(c): + return nn.GroupNorm(8, c) + self.time_embed = nn.Conv2d(tdim, dim, 1, 1) self.view_embed = nn.Conv2d(vdim, dim, 1, 1) self.conv = nn.Sequential( @@ -17,22 +22,28 @@ class Image2DResBlockWithTV(nn.Module): ) def forward(self, x, t, v): - return x+self.conv(x+self.time_embed(t)+self.view_embed(v)) + return x + self.conv(x + self.time_embed(t) + self.view_embed(v)) class NoisyTargetViewEncoder(nn.Module): - def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8): + + def __init__(self, + time_embed_dim, + viewpoint_dim, + run_dim=16, + output_dim=8): super().__init__() self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1) - self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) - self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) - self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) + self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, + viewpoint_dim) + self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, + viewpoint_dim) + self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, + viewpoint_dim) self.final_out = nn.Sequential( - nn.GroupNorm(8, run_dim), - nn.SiLU(True), - nn.Conv2d(run_dim, output_dim, 3, 1, 1) - ) + nn.GroupNorm(8, run_dim), nn.SiLU(True), + nn.Conv2d(run_dim, output_dim, 3, 1, 1)) def forward(self, x, t, v): B, DT = t.shape @@ -47,23 +58,39 @@ class NoisyTargetViewEncoder(nn.Module): x = self.final_out(x) return x + class SpatialUpTimeBlock(nn.Module): + def __init__(self, x_in_dim, t_in_dim, out_dim): super().__init__() - norm_act = lambda c: nn.GroupNorm(8, c) + + def norm_act(c): + return nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16 self.norm = norm_act(x_in_dim) self.silu = nn.SiLU(True) - self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2) + self.conv = nn.ConvTranspose3d( + 
x_in_dim, + out_dim, + kernel_size=3, + padding=1, + output_padding=1, + stride=2) def forward(self, x, t): x = x + self.t_conv(t) return self.conv(self.silu(self.norm(x))) + class SpatialTimeBlock(nn.Module): + def __init__(self, x_in_dim, t_in_dim, out_dim, stride): super().__init__() - norm_act = lambda c: nn.GroupNorm(8, c) + + def norm_act(c): + return nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16 self.bn = norm_act(x_in_dim) self.silu = nn.SiLU(True) @@ -73,61 +100,68 @@ class SpatialTimeBlock(nn.Module): x = x + self.t_conv(t) return self.conv(self.silu(self.bn(x))) + class SpatialTime3DNet(nn.Module): - def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)): - super().__init__() - d0, d1, d2, d3 = dims - dt = time_dim - self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32 - self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1) + def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)): + super().__init__() + d0, d1, d2, d3 = dims + dt = time_dim - self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2) - self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1) - self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1) + self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32 + self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1) - self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2) - self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1) - self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1) + self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2) + self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1) + self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1) - self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2) - self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1) - self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1) + self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2) + self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1) + self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1) - self.conv7 = SpatialUpTimeBlock(d3, dt, d2) - self.conv8 = SpatialUpTimeBlock(d2, dt, d1) - self.conv9 = SpatialUpTimeBlock(d1, dt, d0) + self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2) + self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1) + self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1) - def forward(self, x, t): - B, C = t.shape - t = t.view(B, C, 1, 1, 1) + self.conv7 = SpatialUpTimeBlock(d3, dt, d2) + self.conv8 = SpatialUpTimeBlock(d2, dt, d1) + self.conv9 = SpatialUpTimeBlock(d1, dt, d0) - x = self.init_conv(x) - conv0 = self.conv0(x, t) + def forward(self, x, t): + B, C = t.shape + t = t.view(B, C, 1, 1, 1) - x = self.conv1(conv0, t) - x = self.conv2_0(x, t) - conv2 = self.conv2_1(x, t) + x = self.init_conv(x) + conv0 = self.conv0(x, t) - x = self.conv3(conv2, t) - x = self.conv4_0(x, t) - conv4 = self.conv4_1(x, t) + x = self.conv1(conv0, t) + x = self.conv2_0(x, t) + conv2 = self.conv2_1(x, t) - x = self.conv5(conv4, t) - x = self.conv6_0(x, t) - x = self.conv6_1(x, t) + x = self.conv3(conv2, t) + x = self.conv4_0(x, t) + conv4 = self.conv4_1(x, t) + + x = self.conv5(conv4, t) + x = self.conv6_0(x, t) + x = self.conv6_1(x, t) + + x = conv4 + self.conv7(x, t) + x = conv2 + self.conv8(x, t) + x = conv0 + self.conv9(x, t) + return x - x = conv4 + self.conv7(x, t) - x = conv2 + self.conv8(x, t) - x = conv0 + self.conv9(x, t) - return x class FrustumTVBlock(nn.Module): + def __init__(self, x_dim, t_dim, v_dim, out_dim, stride): super().__init__() - norm_act = lambda c: nn.GroupNorm(8, c) - self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 
16 - self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 + + def norm_act(c): + return nn.GroupNorm(8, c) + + self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16 + self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 self.bn = norm_act(x_dim) self.silu = nn.SiLU(True) self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1) @@ -136,24 +170,37 @@ class FrustumTVBlock(nn.Module): x = x + self.t_conv(t) + self.v_conv(v) return self.conv(self.silu(self.bn(x))) + class FrustumTVUpBlock(nn.Module): + def __init__(self, x_dim, t_dim, v_dim, out_dim): super().__init__() - norm_act = lambda c: nn.GroupNorm(8, c) - self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16 - self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 + + def norm_act(c): + return nn.GroupNorm(8, c) + + self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16 + self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 self.norm = norm_act(x_dim) self.silu = nn.SiLU(True) - self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2) + self.conv = nn.ConvTranspose3d( + x_dim, + out_dim, + kernel_size=3, + padding=1, + output_padding=1, + stride=2) def forward(self, x, t, v): x = x + self.t_conv(t) + self.v_conv(v) return self.conv(self.silu(self.norm(x))) + class FrustumTV3DNet(nn.Module): + def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)): super().__init__() - self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32 + self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32 self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2) self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1) @@ -169,10 +216,10 @@ class FrustumTV3DNet(nn.Module): self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0]) def forward(self, x, t, v): - B,DT = t.shape - t = t.view(B,DT,1,1,1) - B,DV = v.shape - v = v.view(B,DV,1,1,1) + B, DT = t.shape + t = t.view(B, DT, 1, 1, 1) + B, DV = v.shape + v = v.view(B, DV, 1, 1, 1) b, _, d, h, w = x.shape x0 = self.conv0(x) @@ -183,4 +230,4 @@ class FrustumTV3DNet(nn.Module): x2 = self.up0(x3, t, v) + x2 x1 = self.up1(x2, t, v) + x1 x0 = self.up2(x1, t, v) + x0 - return {w: x0, w//2: x1, w//4: x2, w//8: x3} + return {w: x0, w // 2: x1, w // 4: x2, w // 8: x3} diff --git a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_utils.py b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_utils.py index c401c745..e7f2921f 100644 --- a/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_utils.py +++ b/modelscope/models/cv/image_to_3d/ldm/models/diffusion/sync_dreamer_utils.py @@ -10,13 +10,13 @@ def project_and_normalize(ref_grid, src_proj, length): @param length: int @return: b, n, 2 """ - src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n + src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n div_val = src_grid[:, -1:] - div_val[div_val<1e-4] = 1e-4 - src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n) - src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1 - src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1 - src_grid = src_grid.permute(0, 2, 1) # (b, n, 2) + div_val[div_val < 1e-4] = 1e-4 + src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n) + src_grid[:, 0] = src_grid[:, 0] / ((length - 1) / 2) - 1 # scale to -1~1 + src_grid[:, 1] = src_grid[:, 1] / ((length - 1) / 2) - 1 # scale to -1~1 + src_grid = src_grid.permute(0, 2, 1) # (b, n, 2) return src_grid @@ -29,38 +29,55 @@ def construct_project_matrix(x_ratio, y_ratio, Ks, 
poses): @return: """ rfn = Ks.shape[0] - scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device) + scale_m = torch.tensor([x_ratio, y_ratio, 1.0], + dtype=torch.float32, + device=Ks.device) scale_m = torch.diag(scale_m) ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4 - pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device) + pad_vals = torch.zeros([rfn, 1, 4], + dtype=torch.float32, + device=ref_prj.device) pad_vals[:, :, 3] = 1.0 ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4 return ref_prj + def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose): B, _, D, H, W = volume_xyz.shape ratio = warp_size / input_size - warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4 - warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2) + warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4 + warp_coords = project_and_normalize( + volume_xyz.view(B, 3, D * H * W), warp_proj, + warp_size).view(B, D, H, W, 2) return warp_coords -def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None): +def create_target_volume(depth_size, + volume_size, + input_image_size, + pose_target, + K, + near=None, + far=None): device, dtype = pose_target.device, pose_target.dtype # compute a depth range on the unit sphere H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0] - if near is not None and far is not None : + if near is not None and far is not None: # near, far b,1,h,w - depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d - depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1 - depth_values = depth_values * (far - near) + near # b d h w + depth_values = torch.linspace( + 0, 1, steps=depth_size).to(near.device).to(near.dtype) # d + depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1 + depth_values = depth_values * (far - near) + near # b d h w depth_values = depth_values.view(B, 1, D, H * W) else: - near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1 - depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d - depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1 - depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W) + near, far = near_far_from_unit_sphere_using_camera_poses( + pose_target) # b 1 + depth_values = torch.linspace( + 0, 1, steps=depth_size).to(near.device).to(near.dtype) # d + depth_values = depth_values[None, :, None] * ( + far[:, None, :] - near[:, None, :]) + near[:, None, :] # b d 1 + depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H * W) ratio = volume_size / input_image_size @@ -68,20 +85,28 @@ def create_target_volume(depth_size, volume_size, input_image_size, pose_target, # H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0] # creat mesh grid: note reference also means target - ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2) + ref_grid = create_meshgrid( + H, W, normalized_coordinates=False) # (1, H, W, 2) ref_grid = ref_grid.to(device).to(dtype) - ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W) - ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W) - ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W) - ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, 
H*W) + ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W) + ref_grid = ref_grid.reshape(1, 2, H * W) # (1, 2, H*W) + ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W) + ref_grid = torch.cat( + (ref_grid, + torch.ones(B, 1, H * W, dtype=ref_grid.dtype, + device=ref_grid.device)), + dim=1) # (B, 3, H*W) ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W) # unproject to space and transfer to world coordinates. Ks = K - ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4 - ref_proj_inv = torch.inverse(ref_proj) # B,4,4 - ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW - return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W) + ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4 + ref_proj_inv = torch.inverse(ref_proj) # B,4,4 + ref_grid = ref_proj_inv[:, :3, :3] @ ref_grid.view( + B, 3, D * H + * W) + ref_proj_inv[:, :3, 3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW + return ref_grid.reshape(B, 3, D, H, W), depth_values.view(B, 1, D, H, W) + def near_far_from_unit_sphere_using_camera_poses(camera_poses): """ @@ -90,14 +115,16 @@ def near_far_from_unit_sphere_using_camera_poses(camera_poses): near: b,1 far: b,1 """ - R_w2c = camera_poses[..., :3, :3] # b 3 3 - t_w2c = camera_poses[..., :3, 3:] # b 3 1 - camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1 + R_w2c = camera_poses[..., :3, :3] # b 3 3 + t_w2c = camera_poses[..., :3, 3:] # b 3 1 + camera_origin = -R_w2c.permute(0, 2, 1) @ t_w2c # b 3 1 # R_w2c.T @ (0,0,1) = z_dir - camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1 - camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3 - a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1 - b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1 - mid = b / a # b 1 + camera_orient = R_w2c.permute(0, 2, 1)[..., :3, 2:3] # b 3 1 + camera_origin, camera_orient = camera_origin[..., + 0], camera_orient[..., + 0] # b 3 + a = torch.sum(camera_orient**2, dim=-1, keepdim=True) # b 1 + b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1 + mid = b / a # b 1 near, far = mid - 1.0, mid + 1.0 - return near, far \ No newline at end of file + return near, far diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/attention.py b/modelscope/models/cv/image_to_3d/ldm/modules/attention.py index 4e33d0d8..aeab0a06 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/attention.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/attention.py @@ -1,11 +1,13 @@ -from inspect import isfunction import math +from inspect import isfunction + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat +from torch import einsum, nn -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import checkpoint +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import \ + checkpoint def exists(val): @@ -13,7 +15,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -35,6 +37,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) @@ -42,8 +45,11 @@ class GEGLU(nn.Module): def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) + + # feedforward class ConvGEGLU(nn.Module): + def __init__(self, 
dim_in, dim_out): super().__init__() self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0) @@ -54,20 +60,16 @@ class ConvGEGLU(nn.Module): class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -83,54 +85,54 @@ def zero_module(module): def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.heads = heads hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + q, k, v = rearrange( + qkv, + 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads=self.heads, + qkv=3) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + out = rearrange( + out, + 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, + h=h, + w=w) return self.to_out(out) class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -140,7 +142,7 @@ class SpatialSelfAttention(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape + b, c, h, w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) @@ -155,16 +157,22 @@ class SpatialSelfAttention(nn.Module): h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + 
dim_head=64, + dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -172,9 +180,7 @@ class CrossAttention(nn.Module): self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -184,12 +190,13 @@ class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale if exists(mask): - mask = mask>0 + mask = mask > 0 mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) @@ -202,8 +209,15 @@ class CrossAttention(nn.Module): out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) + class BasicSpatialTransformer(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True): + + def __init__(self, + dim, + n_heads, + d_head, + context_dim=None, + checkpoint=True): super().__init__() inner_dim = n_heads * d_head self.proj_in = nn.Sequential( @@ -212,7 +226,12 @@ class BasicSpatialTransformer(nn.Module): nn.GroupNorm(8, inner_dim), nn.ReLU(True), ) - self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn + self.attn = CrossAttention( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim + ) # is a self-attention if not self.disable_self_attn self.out_conv = nn.Sequential( nn.GroupNorm(8, inner_dim), nn.ReLU(True), @@ -221,16 +240,18 @@ class BasicSpatialTransformer(nn.Module): self.proj_out = nn.Sequential( nn.GroupNorm(8, inner_dim), nn.ReLU(True), - zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)), + zero_module( + nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)), ) self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), + self.checkpoint) def _forward(self, x, context): # input - b,_,h,w = x.shape + b, _, h, w = x.shape x_in = x x = self.proj_in(x) @@ -245,44 +266,64 @@ class BasicSpatialTransformer(nn.Module): x = self.proj_out(x) + x_in return x + class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False): super().__init__() self.disable_self_attn = disable_self_attn - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else + None) # is a self-attention if 
not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), + self.checkpoint) def _forward(self, x, context=None): - x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x + class ConvFeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( nn.Conv2d(dim, inner_dim, 1, 1, 0), - nn.GELU() - ) if not glu else ConvGEGLU(dim, inner_dim) + nn.GELU()) if not glu else ConvGEGLU(dim, inner_dim) - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Conv2d(inner_dim, dim_out, 1, 1, 0) - ) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim_out, 1, 1, 0)) def forward(self, x): return self.net(x) @@ -296,31 +337,36 @@ class SpatialTransformer(nn.Module): Then apply standard transformer action. Finally, reshape to image """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, disable_self_attn=False): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, - disable_self_attn=disable_self_attn) - for d in range(depth)] - ) + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn) for d in range(depth) + ]) - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/model.py b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/model.py index 69d910bf..83780c98 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/model.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/model.py @@ -1,12 +1,14 @@ # pytorch_diffusion + derived encoder decoder import math + +import numpy as np import torch import torch.nn as nn -import numpy as 
np from einops import rearrange +from modelscope.models.cv.image_to_3d.ldm.modules.attention import \ + LinearAttention from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config -from modelscope.models.cv.image_to_3d.ldm.modules.attention import LinearAttention def get_timestep_embedding(timesteps, embedding_dim): @@ -26,53 +28,51 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -80,8 +80,14 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -89,34 +95,29 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, 
- kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) def forward(self, x, temb): h = x @@ -125,7 +126,7 @@ class ResnetBlock(nn.Module): h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) @@ -138,42 +139,31 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x+h + return x + h class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" + def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): + def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -183,44 +173,61 @@ class AttnBlock(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) - return x+h_ + return x + h_ -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": +def make_attn(in_channels, attn_type='vanilla'): + assert attn_type in ['vanilla', 'linear', + 'none'], f'attn_type {attn_type} unknown' + print( + f"making attention of type '{attn_type}' with {in_channels} in_channels" + ) + if attn_type == 'vanilla': return AttnBlock(in_channels) - elif attn_type == "none": + elif attn_type == 'none': return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) 
class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type='vanilla'): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -231,69 +238,70 @@ class Model(nn.Module): # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1, ) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + 
in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -303,18 +311,15 @@ class Model(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -336,7 +341,7 @@ class Model(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -347,9 +352,9 @@ class Model(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], + dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -366,12 +371,26 @@ class Model(nn.Module): class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type='vanilla', **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -380,56 +399,58 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1, ) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block 
down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) def forward(self, x): # timestep embedding @@ -443,7 +464,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -460,12 +481,27 @@ class Encoder(nn.Module): class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type='vanilla', + **ignorekwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = 'linear' self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -476,43 +512,44 @@ class Decoder(nn.Module): self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( + # in_ch_mult = (1, ) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print('Working with z of shape {} = {} dimensions.'.format( self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - 
out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -522,18 +559,15 @@ class Decoder(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -549,7 +583,7 @@ class Decoder(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) @@ -569,31 +603,37 @@ class Decoder(nn.Module): class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) + self.model = nn.ModuleList([ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True) + ]) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: + if i in [1, 2, 3]: x = layer(x, None) else: x = layer(x) @@ -605,25 +645,34 @@ class SimpleDecoder(nn.Module): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - 
ch_mult=(2,2), dropout=0.0): + + def __init__(self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) + curr_res = resolution // 2**(self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -632,11 +681,8 @@ class UpsampleDecoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): # upsampling @@ -653,35 +699,48 @@ class UpsampleDecoder(nn.Module): class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + + def __init__(self, + factor, + in_channels, + mid_channels, + out_channels, + depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1) + self.res_block1 = nn.ModuleList([ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.res_block2 = nn.ModuleList([ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth) + ]) - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = torch.nn.functional.interpolate( + x, + size=(int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)))) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -690,17 +749,39 @@ class LatentRescaler(nn.Module): class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + + def __init__(self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + 
resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth) def forward(self, x): x = self.encoder(x) @@ -709,15 +790,38 @@ class MergedRescaleEncoder(nn.Module): class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + + def __init__(self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1): super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth) def forward(self, x): x = self.rescaler(x) @@ -726,17 +830,34 @@ class MergedRescaleDecoder(nn.Module): class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + + def __init__(self, + in_size, + out_size, + in_channels, + out_channels, + ch_mult=2): super().__init__() assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. 
+ (out_size % in_size) + print( + f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}' + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) def forward(self, x): x = self.rescaler(x) @@ -745,32 +866,39 @@ class Upsampler(nn.Module): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): + + def __init__(self, in_channels=None, learned=False, mode='bilinear'): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + print( + f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode' + ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1) def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: + if scale_factor == 1.0: return x else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + x = torch.nn.functional.interpolate( + x, + mode=self.mode, + align_corners=False, + scale_factor=scale_factor) return x + class FirstStagePostProcessor(nn.Module): - def __init__(self, ch_mult:list, in_channels, - pretrained_model:nn.Module=None, + def __init__(self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, reshape=False, n_channels=None, dropout=0., @@ -788,22 +916,25 @@ class FirstStagePostProcessor(nn.Module): if n_channels is None: n_channels = self.pretrained_model.encoder.ch - self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) - self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, - stride=1,padding=1) + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: - blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + blocks.append( + ResnetBlock( + in_channels=ch_in, + out_channels=m * n_channels, + dropout=dropout)) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) - def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() @@ -811,25 +942,23 @@ class FirstStagePostProcessor(nn.Module): for param in self.pretrained_model.parameters(): param.requires_grad = False - @torch.no_grad() - def encode_with_pretrained(self,x): + def encode_with_pretrained(self, x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() - return c + return c - def forward(self,x): + def forward(self, x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) - for submodel, downmodel in zip(self.model,self.downsampler): - z = submodel(z,temb=None) + for 
submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) z = downmodel(z) if self.do_reshape: - z = rearrange(z,'b c h w -> b (h w) c') + z = rearrange(z, 'b c h w -> b (h w) c') return z - diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/openaimodel.py b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/openaimodel.py index 87e00645..5b6ac5fc 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/openaimodel.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/openaimodel.py @@ -1,6 +1,6 @@ +import math from abc import abstractmethod from functools import partial -import math from typing import Iterable import numpy as np @@ -8,16 +8,11 @@ import torch as th import torch.nn as nn import torch.nn.functional as F +from modelscope.models.cv.image_to_3d.ldm.modules.attention import \ + SpatialTransformer from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer + avg_pool_nd, checkpoint, conv_nd, linear, normalization, + timestep_embedding, zero_module) from modelscope.models.cv.image_to_3d.ldm.util import exists @@ -25,11 +20,12 @@ from modelscope.models.cv.image_to_3d.ldm.util import exists def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass -## go +# go class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py @@ -43,7 +39,8 @@ class AttentionPool2d(nn.Module): output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -98,37 +95,46 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') else: - x = F.interpolate(x, scale_factor=2, mode="nearest") + x = F.interpolate(x, scale_factor=2, mode='nearest') if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -141,7 +147,12 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -150,8 +161,12 @@ class Downsample(nn.Module): stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -220,7 +235,8 @@ class ResBlock(TimestepBlock): nn.SiLU(), linear( emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, ), ) self.out_layers = nn.Sequential( @@ -228,18 +244,18 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + conv_nd( + dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + dims, channels, self.out_channels, 3, padding=1) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd(dims, channels, self.out_channels, + 1) def forward(self, x, emb): """ @@ -248,10 +264,8 @@ class ResBlock(TimestepBlock): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. 
""" - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), + self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -265,7 +279,7 @@ class ResBlock(TimestepBlock): emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] - if self.use_scale_shift_norm: # False + if self.use_scale_shift_norm: # False out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = th.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift @@ -298,7 +312,7 @@ class AttentionBlock(nn.Module): else: assert ( channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint self.norm = normalization(channels) @@ -313,8 +327,10 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x, ), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -341,7 +357,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. 
- matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,13 +379,14 @@ class QKVAttentionLegacy(nn.Module): bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split( + ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + 'bct,bcs->bts', q * scale, + k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) + a = th.einsum('bts,bcs->bct', weight, v) return a.reshape(bs, -1, length) @staticmethod @@ -398,12 +415,13 @@ class QKVAttention(nn.Module): q, k, v = qkv.chunk(3, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( - "bct,bcs->bts", + 'bct,bcs->bts', (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = th.einsum('bts,bcs->bct', weight, + v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @staticmethod @@ -442,40 +460,42 @@ class UNetModel(nn.Module): """ def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None - ): + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your \ + cross-attention conditioning...' if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your \ + cross-attention conditioning...' + from omegaconf.listconfig import ListConfig if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -497,20 +517,28 @@ class UNetModel(nn.Module): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + raise ValueError( + 'provide num_res_blocks either as an int (globally constant) or ' + 'as a list/tuple (per-level) with the same length as channel_mult' + ) self.num_res_blocks = num_res_blocks - #self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") # todo: convert to warning + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i + ], + range(len(num_attention_blocks)))) + print( + f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. ' + f'This option has LESS priority than attention_resolutions {attention_resolutions}, ' + f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, ' + f'attention will still not be set.' 
+ ) # todo: convert to warning self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -534,13 +562,10 @@ class UNetModel(nn.Module): if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) # 0 + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ]) # 0 self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels @@ -559,21 +584,22 @@ class UNetModel(nn.Module): ) ] ch = mult * model_channels - if ds in attention_resolutions: # always True + if ds in attention_resolutions: # always True if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False - if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + if not exists(num_attention_blocks + ) or nr < num_attention_blocks[level]: layers.append( AttentionBlock( ch, @@ -581,11 +607,14 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa - ) - ) + ) if not use_spatial_transformer else + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa)) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -602,12 +631,8 @@ class UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) ) ch = out_ch input_block_chans.append(ch) @@ -620,7 +645,7 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -637,9 +662,13 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ), + ) if not use_spatial_transformer else + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim), ResBlock( ch, time_embed_dim, @@ -674,14 +703,15 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False - if not exists(num_attention_blocks) or i < 
num_attention_blocks[level]: + if not exists(num_attention_blocks + ) or i < num_attention_blocks[level]: layers.append( AttentionBlock( ch, @@ -689,11 +719,14 @@ class UNetModel(nn.Module): num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa - ) - ) + ) if not use_spatial_transformer else + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa)) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( @@ -706,10 +739,8 @@ class UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) + ) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -717,14 +748,15 @@ class UNetModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + zero_module( + conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -742,7 +774,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -753,18 +785,19 @@ class UNetModel(nn.Module): """ assert (y is not None) == ( self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" + ), 'must specify y if and only if the model is class-conditional' hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N - emb = self.time_embed(t_emb) # + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False) # N + emb = self.time_embed(t_emb) # if self.num_classes is not None: - assert y.shape == (x.shape[0],) + assert y.shape == (x.shape[0], ) emb = emb + self.label_emb(y) h = x.type(self.dtype) for module in self.input_blocks: - h = module(h, emb, context) # conv + h = module(h, emb, context) # conv hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: @@ -783,30 +816,28 @@ class EncoderUNetModel(nn.Module): For usage, see UNet. 
""" - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs - ): + def __init__(self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool='adaptive', + *args, + **kwargs): super().__init__() if num_heads_upsample == -1: @@ -833,13 +864,10 @@ class EncoderUNetModel(nn.Module): linear(time_embed_dim, time_embed_dim), ) - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1)) + ]) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels @@ -866,8 +894,7 @@ class EncoderUNetModel(nn.Module): num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order, - ) - ) + )) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -884,12 +911,8 @@ class EncoderUNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) ) ch = out_ch input_block_chans.append(ch) @@ -923,7 +946,7 @@ class EncoderUNetModel(nn.Module): ) self._feature_size += ch self.pool = pool - if pool == "adaptive": + if pool == 'adaptive': self.out = nn.Sequential( normalization(ch), nn.SiLU(), @@ -931,22 +954,21 @@ class EncoderUNetModel(nn.Module): zero_module(conv_nd(dims, ch, out_channels, 1)), nn.Flatten(), ) - elif pool == "attention": + elif pool == 'attention': assert num_head_channels != -1 self.out = nn.Sequential( normalization(ch), nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), + AttentionPool2d((image_size // ds), ch, num_head_channels, + out_channels), ) - elif pool == "spatial": + elif pool == 'spatial': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), ) - elif pool == "spatial_v2": + elif pool == 'spatial_v2': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), normalization(2048), @@ -954,7 +976,7 @@ class EncoderUNetModel(nn.Module): nn.Linear(2048, self.out_channels), ) else: - raise NotImplementedError(f"Unexpected {pool} pooling") + raise NotImplementedError(f'Unexpected {pool} pooling') def convert_to_fp16(self): """ @@ -977,20 +999,20 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. 
""" - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) results = [] h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb) - if self.pool.startswith("spatial"): + if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): + if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = th.cat(results, axis=-1) return self.out(h) else: h = h.type(x.dtype) return self.out(h) - diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/util.py b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/util.py index bd059502..a63d05a3 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/util.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/diffusionmodules/util.py @@ -7,50 +7,65 @@ # # thanks! - -import os import math +import os + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import repeat from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config -def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if schedule == "linear": +def make_beta_schedule(schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + if schedule == 'linear': betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 - ) + torch.linspace( + linear_start**0.5, + linear_end**0.5, + n_timestep, + dtype=torch.float64)**2) - elif schedule == "cosine": + elif schedule == 'cosine': timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + + cosine_s) alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) - elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + elif schedule == 'sqrt_linear': + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == 'sqrt': + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): +def make_ddim_timesteps(ddim_discr_method, + num_ddim_timesteps, + num_ddpm_timesteps, + verbose=True): if ddim_discr_method == 'uniform': c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), + num_ddim_timesteps))**2).astype(int) else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha 
values right (the ones from first scale to data during sampling) @@ -60,17 +75,26 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep return steps_out -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): +def make_ddim_sampling_parameters(alphacums, + ddim_timesteps, + eta, + verbose=True): # select alphas for computing the variance schedule alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + alphas_prev = np.asarray([alphacums[0]] + + alphacums[ddim_timesteps[:-1]].tolist()) # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa + (1 - alphas / alphas_prev)) # noqa if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print( + f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}' + ) + print( + f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}' + ) return sigmas, alphas, alphas_prev @@ -96,7 +120,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) + return out.reshape(b, *((1, ) * (len(x_shape) - 1))) def checkpoint(func, inputs, params, flag): @@ -117,6 +141,7 @@ def checkpoint(func, inputs, params, flag): class CheckpointFunction(torch.autograd.Function): + @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function @@ -129,7 +154,9 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + ctx.input_tensors = [ + x.detach().requires_grad_(True) for x in ctx.input_tensors + ] with torch.enable_grad(): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d @@ -160,12 +187,14 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding @@ -207,14 +236,17 @@ def normalization(channels): # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): + def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): + def forward(self, x): return super().forward(x.float()).type(x.dtype) + def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. 
@@ -225,7 +257,7 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + raise ValueError(f'unsupported dimensions: {dims}') def linear(*args, **kwargs): @@ -245,7 +277,7 @@ def avg_pool_nd(dims, *args, **kwargs): return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") + raise ValueError(f'unsupported dimensions: {dims}') class HybridConditioner(nn.Module): @@ -253,7 +285,8 @@ class HybridConditioner(nn.Module): def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + self.crossattn_conditioner = instantiate_from_config( + c_crossattn_config) def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) @@ -262,6 +295,13 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + + def repeat_noise(): + return torch.randn((1, *shape[1:]), + device=device).repeat(shape[0], + *((1, ) * (len(shape) - 1))) + + def noise(): + return torch.randn(shape, device=device) + + return repeat_noise() if repeat else noise() diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/distributions/distributions.py b/modelscope/models/cv/image_to_3d/ldm/modules/distributions/distributions.py index f2b8ef90..24cbbbc8 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/distributions/distributions.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/distributions/distributions.py @@ -1,8 +1,9 @@ -import torch import numpy as np +import torch class AbstractDistribution: + def sample(self): raise NotImplementedError() @@ -11,6 +12,7 @@ class AbstractDistribution: class DiracDistribution(AbstractDistribution): + def __init__(self, value): self.value = value @@ -22,6 +24,7 @@ class DiracDistribution(AbstractDistribution): class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) @@ -30,10 +33,12 @@ class DiagonalGaussianDistribution(object): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) return x def kl(self, other=None): @@ -41,21 +46,22 @@ class DiagonalGaussianDistribution(object): return torch.Tensor([0.]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3]) - def nll(self, sample, 
dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): @@ -64,7 +70,8 @@ class DiagonalGaussianDistribution(object): def normal_kl(mean1, logvar1, mean2, logvar2): """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + (source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/ + guided_diffusion/losses.py#L12) Compute the KL divergence between two gaussians. Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. @@ -74,7 +81,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): if isinstance(obj, torch.Tensor): tensor = obj break - assert tensor is not None, "at least one argument must be a Tensor" + assert tensor is not None, 'at least one argument must be a Tensor' # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). @@ -84,9 +91,5 @@ def normal_kl(mean1, logvar1, mean2, logvar2): ] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa + ((mean1 - mean2)**2) * torch.exp(-logvar2)) # noqa diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/clip.py b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/clip.py index 0b546d32..c61c3432 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/clip.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/clip.py @@ -2,15 +2,17 @@ import hashlib import os import urllib import warnings -from typing import Any, Union, List -from pkg_resources import packaging +from typing import Any, List, Union import torch from PIL import Image -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from pkg_resources import packaging +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) from tqdm import tqdm -from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model +from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import \ + build_model try: from torchvision.transforms import InterpolationMode @@ -18,23 +20,40 @@ try: except ImportError: BICUBIC = Image.BICUBIC +if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.7.1'): + warnings.warn('PyTorch version 1.7.1 or higher is recommended') -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): - warnings.warn("PyTorch version 1.7.1 or higher is recommended") - - -__all__ = ["available_models", "load"] +__all__ = ['available_models', 'load'] _MODELS = { - "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - "RN50x16": 
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", - "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", - "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", - "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", - "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", + 'RN50': + 'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/\ + RN50.pt', + 'RN101': + 'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/\ + RN101.pt', + 'RN50x4': + 'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/\ + RN50x4.pt', + 'RN50x16': + 'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/\ + RN50x16.pt', + 'RN50x64': + 'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/\ + RN50x64.pt', + 'ViT-B/32': + 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/\ + ViT-B-32.pt', + 'ViT-B/16': + 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/\ + ViT-B-16.pt', + 'ViT-L/14': + 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/\ + ViT-L-14.pt', + 'ViT-L/14@336px': + 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/\ + ViT-L-14-336px.pt', } @@ -42,20 +61,30 @@ def _download(url: str, root: str): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) - expected_sha256 = url.split("/")[-2] + expected_sha256 = url.split('/')[-2] download_target = os.path.join(root, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") + raise RuntimeError( + f'{download_target} exists and is not a regular file') if os.path.isfile(download_target): - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + if hashlib.sha256(open(download_target, + 'rb').read()).hexdigest() == expected_sha256: return download_target else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + warnings.warn( + f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file' + ) - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + with urllib.request.urlopen(url) as source, open(download_target, + 'wb') as output: + with tqdm( + total=int(source.info().get('Content-Length')), + ncols=80, + unit='iB', + unit_scale=True, + unit_divisor=1024) as loop: while True: 
buffer = source.read(8192) if not buffer: @@ -64,14 +93,17 @@ def _download(url: str, root: str): output.write(buffer) loop.update(len(buffer)) - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: - raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + if hashlib.sha256(open(download_target, + 'rb').read()).hexdigest() != expected_sha256: + raise RuntimeError( + 'Model has been downloaded but the SHA256 checksum does not not match' + ) return download_target def _convert_image_to_rgb(image): - return image.convert("RGB") + return image.convert('RGB') def _transform(n_px): @@ -80,7 +112,8 @@ def _transform(n_px): CenterCrop(n_px), _convert_image_to_rgb, ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), ]) @@ -89,7 +122,11 @@ def available_models() -> List[str]: return list(_MODELS.keys()) -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): +def load(name: str, + device: Union[str, torch.device] = 'cuda' + if torch.cuda.is_available() else 'cpu', + jit: bool = False, + download_root: str = None): """Load a CLIP model Parameters @@ -115,37 +152,47 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if name in _MODELS: - model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + model_path = _download( + _MODELS[name], download_root + or os.path.expanduser('~/.cache/clip')) elif os.path.isfile(name): model_path = name else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + raise RuntimeError( + f'Model {name} not found; available models = {available_models()}') with open(model_path, 'rb') as opened_file: try: # loading JIT archive - model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + model = torch.jit.load( + opened_file, map_location=device if jit else 'cpu').eval() state_dict = None except RuntimeError: # loading saved state dict if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + warnings.warn( + f'File {model_path} is not a JIT archive. Loading as a state dict instead' + ) jit = False - state_dict = torch.load(opened_file, map_location="cpu") + state_dict = torch.load(opened_file, map_location='cpu') if not jit: model = build_model(state_dict or model.state_dict()).to(device) - if str(device) == "cpu": + if str(device) == 'cpu': model.float() return model, _transform(model.visual.input_resolution) # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [ + n for n in device_holder.graph.findAllNodes('prim::Constant') + if 'Device' in repr(n) + ][-1] def _node_get(node: torch._C.Node, key: str): """Gets attributes of a node which is polymorphic over return type. 
- + From https://github.com/pytorch/pytorch/pull/82628 """ sel = node.kindOf(key) @@ -153,16 +200,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a def patch_device(module): try: - graphs = [module.graph] if hasattr(module, "graph") else [] + graphs = [module.graph] if hasattr(module, 'graph') else [] except RuntimeError: graphs = [] - if hasattr(module, "forward1"): + if hasattr(module, 'forward1'): graphs.append(module.forward1.graph) for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + for node in graph.findAllNodes('prim::Constant'): + if 'value' in node.attributeNames() and str( + _node_get(node, 'value')).startswith('cuda'): node.copyAttributes(device_node) model.apply(patch_device) @@ -170,25 +218,28 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a patch_device(model.encode_text) # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + if str(device) == 'cpu': + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] float_node = float_input.node() def patch_float(module): try: - graphs = [module.graph] if hasattr(module, "graph") else [] + graphs = [module.graph] if hasattr(module, 'graph') else [] except RuntimeError: graphs = [] - if hasattr(module, "forward1"): + if hasattr(module, 'forward1'): graphs.append(module.forward1.graph) for graph in graphs: - for node in graph.findAllNodes("aten::to"): + for node in graph.findAllNodes('aten::to'): inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if _node_get(inputs[i].node(), "value") == 5: + for i in [ + 1, 2 + ]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), 'value') == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float) diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/model.py b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/model.py index 232b7792..c3d0471f 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/model.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/model.py @@ -33,11 +33,16 @@ class Bottleneck(nn.Module): if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) def forward(self, x: torch.Tensor): identity = x @@ -56,9 +61,15 @@ class Bottleneck(nn.Module): class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): super().__init__() - self.positional_embedding = 
nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) @@ -70,14 +81,17 @@ class AttentionPool2d(nn.Module): x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( - query=x[:1], key=x, value=x, + query=x[:1], + key=x, + value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, @@ -86,8 +100,7 @@ class AttentionPool2d(nn.Module): out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, - need_weights=False - ) + need_weights=False) return x.squeeze(0) @@ -99,19 +112,27 @@ class ModifiedResNet(nn.Module): - The final pooling layer is a QKV attention instead of an average pool """ - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): super().__init__() self.output_dim = output_dim self.input_resolution = input_resolution # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.relu2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.relu3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) @@ -124,7 +145,8 @@ class ModifiedResNet(nn.Module): self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) def _make_layer(self, planes, blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] @@ -136,6 +158,7 @@ class ModifiedResNet(nn.Module): return nn.Sequential(*layers) def forward(self, x): + def stem(x): x = self.relu1(self.bn1(self.conv1(x))) x = self.relu2(self.bn2(self.conv2(x))) @@ -164,27 +187,34 @@ class LayerNorm(nn.LayerNorm): class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) - self.mlp = 
nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) @@ -193,26 +223,42 @@ class ResidualAttentionBlock(nn.Module): class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): super().__init__() self.width = width self.layers = layers - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) def forward(self, x: torch.Tensor): return self.resblocks(x) class VisionTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): super().__init__() self.input_resolution = input_resolution self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) - scale = width ** -0.5 + scale = width**-0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) self.ln_pre = LayerNorm(width) self.transformer = Transformer(width, layers, heads) @@ -222,9 +268,13 @@ class VisionTransformer(nn.Module): def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + torch_zeros = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([self.class_embedding.to(x.dtype) + torch_zeros, x], + dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x) @@ -241,20 +291,21 @@ class VisionTransformer(nn.Module): class CLIP(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - 
vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int - ): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): super().__init__() self.context_length = context_length @@ -266,8 +317,7 @@ class CLIP(nn.Module): output_dim=embed_dim, heads=vision_heads, input_resolution=image_resolution, - width=vision_width - ) + width=vision_width) else: vision_heads = vision_width // 64 self.visual = VisionTransformer( @@ -276,22 +326,22 @@ class CLIP(nn.Module): width=vision_width, layers=vision_layers, heads=vision_heads, - output_dim=embed_dim - ) + output_dim=embed_dim) self.transformer = Transformer( width=transformer_width, layers=transformer_layers, heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) + attn_mask=self.build_attention_mask()) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters() @@ -302,20 +352,24 @@ class CLIP(nn.Module): if isinstance(self.visual, ModifiedResNet): if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 + std = self.visual.attnpool.c_proj.in_features**-0.5 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): + if name.endswith('bn3.weight'): nn.init.zeros_(param) - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) @@ -323,13 +377,14 @@ class CLIP(nn.Module): nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = 
torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) + mask.fill_(float('-inf')) mask.triu_(1) # zero out the lower diagonal return mask @@ -341,7 +396,8 @@ class CLIP(nn.Module): return self.visual(image.type(self.dtype)) def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.type(self.dtype) x = x.permute(1, 0, 2) # NLD -> LND @@ -351,7 +407,8 @@ class CLIP(nn.Module): # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection return x @@ -360,7 +417,8 @@ class CLIP(nn.Module): text_features = self.encode_text(text) # normalized features - image_features = image_features / image_features.norm(dim=1, keepdim=True) + image_features = image_features / image_features.norm( + dim=1, keepdim=True) text_features = text_features / text_features.norm(dim=1, keepdim=True) # cosine similarity as logits @@ -375,21 +433,24 @@ class CLIP(nn.Module): def convert_weights(model: nn.Module): """Convert applicable model parameters to fp16""" - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() + def _convert_weights_to_fp16(layer): + if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)): + layer.weight.data = layer.weight.data.half() + if layer.bias is not None: + layer.bias.data = layer.bias.data.half() - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) + if isinstance(layer, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(layer, attr) if tensor is not None: tensor.data = tensor.data.half() - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) + for name in ['text_projection', 'proj']: + if hasattr(layer, name): + attr = getattr(layer, name) if attr is not None: attr.data = attr.data.half() @@ -397,37 +458,51 @@ def convert_weights(model: nn.Module): def build_model(state_dict: dict): - vit = "visual.proj" in state_dict + vit = 'visual.proj' in state_dict if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) image_resolution = vision_patch_size * grid_size else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + counts: list = [ + len( + set( + k.split('.')[2] for k in state_dict + 
if k.startswith(f'visual.layer{b}'))) + for b in [1, 2, 3, 4] + ] vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0] + output_width = round( + (state_dict['visual.attnpool.positional_embedding'].shape[0] + - 1)**0.5) vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + assert output_width**2 + 1 == state_dict[ + 'visual.attnpool.positional_embedding'].shape[0] image_resolution = output_width * 32 - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - ) + model = CLIP(embed_dim, image_resolution, vision_layers, vision_width, + vision_patch_size, context_length, vocab_size, + transformer_width, transformer_heads, transformer_layers) - for key in ["input_resolution", "context_length", "vocab_size"]: + for key in ['input_resolution', 'context_length', 'vocab_size']: if key in state_dict: del state_dict[key] diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/simple_tokenizer.py b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/simple_tokenizer.py index 0a66286b..ffd0d092 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/simple_tokenizer.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/clip/simple_tokenizer.py @@ -9,7 +9,9 @@ import regex as re @lru_cache() def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') @lru_cache() @@ -23,13 +25,17 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2**8+n) + cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -60,34 +66,41 @@ def whitespace_clean(text): class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] + vocab = vocab + [v + '' for v in vocab] for merge in merges: vocab.append(''.join(merge)) vocab.extend(['<|startoftext|>', '<|endoftext|>']) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) + word = tuple(token[:-1]) + (token[-1] + '', ) pairs = get_pairs(word) if not pairs: - return token+'' + return token + '' while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -98,12 +111,13 @@ class SimpleTokenizer(object): j = word.index(first, i) new_word.extend(word[i:j]) i = j - except: + except BaseException: new_word.extend(word[i:]) break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) @@ -122,11 +136,14 @@ class SimpleTokenizer(object): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') return text diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py 
b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py index 9b62b1e0..d8fbc03d 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py @@ -1,28 +1,45 @@ +import random +from functools import partial + +import kornia +import kornia.augmentation as K +import numpy as np import torch import torch.nn as nn -import numpy as np -from functools import partial -import kornia +import torch.nn.functional as F +from torchvision import transforms +from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel, + T5EncoderModel, T5Tokenizer) -from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test -from modelscope.models.cv.image_to_3d.ldm.util import default +from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import ( + extract_into_tensor, make_beta_schedule, noise_like) # import clip from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip +# TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import ( + Encoder, TransformerWrapper) +from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures +from modelscope.models.cv.image_to_3d.ldm.util import (default, + instantiate_from_config) class AbstractEncoder(nn.Module): + def __init__(self): super().__init__() def encode(self, *args, **kwargs): raise NotImplementedError + class IdentityEncoder(AbstractEncoder): def encode(self, x): return x + class FaceClipEncoder(AbstractEncoder): + def __init__(self, augment=True, retreival_key=None): super().__init__() self.encoder = FrozenCLIPImageEmbedder() @@ -35,16 +52,16 @@ class FaceClipEncoder(AbstractEncoder): x_offset = 125 if self.retreival_key: # Assumes retrieved image are packed into the second half of channels - face = img[:,3:,190:440,x_offset:(512-x_offset)] - other = img[:,:3,...].clone() + face = img[:, 3:, 190:440, x_offset:(512 - x_offset)] + other = img[:, :3, ...].clone() else: - face = img[:,:,190:440,x_offset:(512-x_offset)] + face = img[:, :, 190:440, x_offset:(512 - x_offset)] other = img.clone() if self.augment: face = K.RandomHorizontalFlip()(face) - other[:,:,190:440,x_offset:(512-x_offset)] *= 0 + other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0 encodings = [ self.encoder.encode(face), self.encoder.encode(other), @@ -55,26 +72,32 @@ class FaceClipEncoder(AbstractEncoder): def encode(self, img): if isinstance(img, list): # Uncondition - return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + return torch.zeros( + (1, 2, 768), + device=self.encoder.model.visual.conv1.weight.device) return self(img) + class FaceIdClipEncoder(AbstractEncoder): + def __init__(self): super().__init__() self.encoder = FrozenCLIPImageEmbedder() for p in self.encoder.parameters(): p.requires_grad = False - self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True) + self.id = FrozenFaceEncoder( + '/home/jpinkney/code/stable-diffusion/model_ir_se50.pth', + augment=True) def forward(self, img): encodings = [] with torch.no_grad(): - face = kornia.geometry.resize(img, (256, 256), - interpolation='bilinear', align_corners=True) + face = kornia.geometry.resize( + img, (256, 256), interpolation='bilinear', align_corners=True) other = 
img.clone() - other[:,:,184:452,122:396] *= 0 + other[:, :, 184:452, 122:396] *= 0 encodings = [ self.id.encode(face), self.encoder.encode(other), @@ -85,11 +108,15 @@ class FaceIdClipEncoder(AbstractEncoder): def encode(self, img): if isinstance(img, list): # Uncondition - return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + return torch.zeros( + (1, 2, 768), + device=self.encoder.model.visual.conv1.weight.device) return self(img) + class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): super().__init__() self.key = key @@ -106,11 +133,19 @@ class ClassEmbedder(nn.Module): class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + + def __init__(self, + n_embed, + n_layer, + vocab_size, + max_seq_len=77, + device='cuda'): super().__init__() self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer)) + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) def forward(self, tokens): tokens = tokens.to(self.device) # meh @@ -123,18 +158,25 @@ class TransformerEmbedder(AbstractEncoder): class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): + + def __init__(self, device='cuda', vq_interface=True, max_length=77): super().__init__() from transformers import BertTokenizerFast # TODO: add to reuquirements - self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') self.device = device self.vq_interface = vq_interface self.max_length = max_length def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt') + tokens = batch_encoding['input_ids'].to(self.device) return tokens @torch.no_grad() @@ -150,20 +192,30 @@ class BERTTokenizer(AbstractEncoder): class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): + + def __init__(self, + n_embed, + n_layer, + vocab_size=30522, + max_seq_len=77, + device='cuda', + use_tokenizer=True, + embedding_dropout=0.0): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: - self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.tknz_fn = BERTTokenizer( + vq_interface=False, max_length=max_seq_len) self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - emb_dropout=embedding_dropout) + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) def 
forward(self, text): if self.use_tknz_fn: - tokens = self.tknz_fn(text)#.to(self.device) + tokens = self.tknz_fn(text) # .to(self.device) else: tokens = text z = self.transformer(tokens, return_embeddings=True) @@ -174,8 +226,6 @@ class BERTEmbedder(AbstractEncoder): return self(text) -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" @@ -184,24 +234,41 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__(self, + version='google/t5-v1_1-large', + device='cuda', + max_length=77 + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() - self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') - self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.tokenizer = T5Tokenizer.from_pretrained( + version, + cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models' + ) + self.transformer = T5EncoderModel.from_pretrained( + version, + cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models' + ) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt') + tokens = batch_encoding['input_ids'].to(self.device) outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state @@ -210,10 +277,9 @@ class FrozenT5Embedder(AbstractEncoder): def encode(self, text): return self(text) -from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures -import kornia.augmentation as K class FrozenFaceEncoder(AbstractEncoder): + def __init__(self, model_path, augment=False): super().__init__() self.loss_fn = IDFeatures(model_path) @@ -242,8 +308,8 @@ class FrozenFaceEncoder(AbstractEncoder): if self.augment is not None: # Transforms require 0-1 - img = self.augment((img + 1)/2) - img = 2*img - 1 + img = self.augment((img + 1) / 2) + img = 2 * img - 1 feat = self.loss_fn(img, crop=True) feat = self.mapper(feat.unsqueeze(1)) @@ -252,26 +318,43 @@ class FrozenFaceEncoder(AbstractEncoder): def encode(self, img): return self(img) + class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + + def __init__(self, + version='openai/clip-vit-large-patch14', + device='cuda', + max_length=77): # clip-vit-base-patch32 super().__init__() - 
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') - self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.tokenizer = CLIPTokenizer.from_pretrained( + version, + cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models' + ) + self.transformer = CLIPTextModel.from_pretrained( + version, + cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models' + ) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt') + tokens = batch_encoding['input_ids'].to(self.device) outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state @@ -280,36 +363,47 @@ class FrozenCLIPEmbedder(AbstractEncoder): def encode(self, text): return self(text) -import torch.nn.functional as F -from transformers import CLIPVisionModel + class ClipImageProjector(AbstractEncoder): """ Uses the CLIP image encoder. """ - def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32 + + def __init__(self, + version='openai/clip-vit-large-patch14', + max_length=77): # clip-vit-base-patch32 super().__init__() self.model = CLIPVisionModel.from_pretrained(version) self.model.train() - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? self.antialias = True self.mapper = torch.nn.Linear(1024, 768) - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.register_buffer( + 'mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer( + 'std', + torch.Tensor([0.26862954, 0.26130258, 0.27577711]), + persistent=False) null_cond = self.get_null_cond(version, max_length) self.register_buffer('null_cond', null_cond) @torch.no_grad() def get_null_cond(self, version, max_length): device = self.mean.device - embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) - null_cond = embedder([""]) + embedder = FrozenCLIPEmbedder( + version=version, device=device, max_length=max_length) + null_cond = embedder(['']) return null_cond def preprocess(self, x): # Expects inputs in the range -1, 1 - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic',align_corners=True, - antialias=self.antialias) + x = kornia.geometry.resize( + x, (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias) x = (x + 1.) / 2. 
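        # inputs arrive in [-1, 1]; the line above maps them to [0, 1] before the CLIP mean/std normalization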
# renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) @@ -323,15 +417,23 @@ class ClipImageProjector(AbstractEncoder): outputs = self.model(pixel_values=x) last_hidden_state = outputs.last_hidden_state last_hidden_state = self.mapper(last_hidden_state) - return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0]) + return F.pad( + last_hidden_state, + [0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0]) def encode(self, im): return self(im) + class ProjectedFrozenCLIPEmbedder(AbstractEncoder): - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + + def __init__(self, + version='openai/clip-vit-large-patch14', + device='cuda', + max_length=77): # clip-vit-base-patch32 super().__init__() - self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) + self.embedder = FrozenCLIPEmbedder( + version=version, device=device, max_length=max_length) self.projection = torch.nn.Linear(768, 768) def forward(self, text): @@ -341,31 +443,41 @@ class ProjectedFrozenCLIPEmbedder(AbstractEncoder): def encode(self, text): return self(text) + class FrozenCLIPImageEmbedder(AbstractEncoder): """ Uses the CLIP image encoder. Not actually frozen... If you want that set cond_stage_trainable=False in cfg """ + def __init__( - self, - model='ViT-L/14', - jit=False, - device='cpu', - antialias=False, - ): + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=False, + ): super().__init__() self.model, _ = clip.load(name=model, device=device, jit=jit) # We don't use the text part so delete it del self.model.transformer self.antialias = antialias - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.register_buffer( + 'mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer( + 'std', + torch.Tensor([0.26862954, 0.26130258, 0.27577711]), + persistent=False) def preprocess(self, x): # Expects inputs in the range -1, 1 - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic',align_corners=True, - antialias=self.antialias) + x = kornia.geometry.resize( + x, (224, 224), + interpolation='bicubic', + align_corners=True, + antialias=self.antialias) x = (x + 1.) / 2. # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) @@ -382,35 +494,41 @@ class FrozenCLIPImageEmbedder(AbstractEncoder): def encode(self, im): return self(im).unsqueeze(1) -from torchvision import transforms -import random class FrozenCLIPImageMutliEmbedder(AbstractEncoder): """ Uses the CLIP image encoder. Not actually frozen... 
If you want that set cond_stage_trainable=False in cfg """ + def __init__( - self, - model='ViT-L/14', - jit=False, - device='cpu', - antialias=True, - max_crops=5, - ): + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=True, + max_crops=5, + ): super().__init__() self.model, _ = clip.load(name=model, device=device, jit=jit) # We don't use the text part so delete it del self.model.transformer self.antialias = antialias - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.register_buffer( + 'mean', + torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + persistent=False) + self.register_buffer( + 'std', + torch.Tensor([0.26862954, 0.26130258, 0.27577711]), + persistent=False) self.max_crops = max_crops def preprocess(self, x): # Expects inputs in the range -1, 1 - randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1)) + randcrop = transforms.RandomResizedCrop( + 224, scale=(0.085, 1.0), ratio=(1, 1)) max_crops = self.max_crops patches = [] crops = [randcrop(x) for _ in range(max_crops)] @@ -441,7 +559,9 @@ class FrozenCLIPImageMutliEmbedder(AbstractEncoder): def encode(self, im): return self(im) + class SpatialRescaler(nn.Module): + def __init__(self, n_stages=1, method='bilinear', @@ -452,19 +572,24 @@ class SpatialRescaler(nn.Module): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 - assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + assert method in [ + 'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area' + ] self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.interpolator = partial( + torch.nn.functional.interpolate, mode=method) self.remap_output = out_channels is not None if self.remap_output: - print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') - self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + print( + f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.' 
+ ) + self.channel_mapper = nn.Conv2d( + in_channels, out_channels, 1, bias=bias) - def forward(self,x): + def forward(self, x): for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) - if self.remap_output: x = self.channel_mapper(x) return x @@ -473,25 +598,38 @@ class SpatialRescaler(nn.Module): return self(x) -from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config -from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like - - class LowScaleEncoder(nn.Module): - def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64, + + def __init__(self, + model_config, + linear_start, + linear_end, + timesteps=1000, + max_noise_level=250, + output_size=64, scale_factor=1.0): super().__init__() self.max_noise_level = max_noise_level self.model = instantiate_from_config(model_config) - self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start, - linear_end=linear_end) + self.augmentation_schedule = self.register_schedule( + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end) self.out_size = output_size self.scale_factor = scale_factor - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) + def register_schedule(self, + beta_schedule='linear', + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) @@ -500,33 +638,45 @@ class LowScaleEncoder(nn.Module): self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[ + 0] == self.num_timesteps, 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer('alphas_cumprod_prev', + to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer('sqrt_alphas_cumprod', + to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', + to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1. 
/ alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1. / alphas_cumprod - 1))) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, + x_start.shape) * noise) def forward(self, x): z = self.model.encode(x).sample() z = z * self.scale_factor - noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + noise_level = torch.randint( + 0, self.max_noise_level, (x.shape[0], ), device=x.device).long() z = self.q_sample(z, noise_level) if self.out_size is not None: - z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode + z = torch.nn.functional.interpolate( + z, size=self.out_size, + mode='nearest') # TODO: experiment with mode # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) return z, noise_level @@ -535,10 +685,13 @@ class LowScaleEncoder(nn.Module): return self.model.decode(z) -if __name__ == "__main__": +if __name__ == '__main__': from ldm.util import count_params - sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"] - model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda() + sentences = [ + 'a hedgehog drinking a whiskey', 'der mond ist aufgegangen', + "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'" + ] + model = FrozenT5Embedder(version='google/t5-v1_1-xl').cuda() count_params(model, True) z = model(sentences) print(z.shape) @@ -548,4 +701,4 @@ if __name__ == "__main__": z = model(sentences) print(z.shape) - print("done.") + print('done.') diff --git a/modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py b/modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py index 5fc15bf9..0e5d7b8f 100644 --- a/modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py +++ b/modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py @@ -1,28 +1,26 @@ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F +from collections import namedtuple from functools import partial from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import einsum, nn # constants DEFAULT_DIM_HEAD = 64 -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) +Intermediates = namedtuple('Intermediates', + ['pre_softmax_attn', 'post_softmax_attn']) -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates' -]) +LayerIntermediates = namedtuple('Intermediates', + ['hiddens', 'attn_intermediates']) class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): super().__init__() self.emb = nn.Embedding(max_seq_len, dim) @@ -37,13 +35,15 @@ class AbsolutePositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1. 
/ (10000**(torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) def forward(self, x, seq_dim=1, offset=0): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + t = torch.arange( + x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return emb[None, :, :] @@ -51,6 +51,7 @@ class FixedPositionalEmbedding(nn.Module): # helpers + def exists(val): return val is not None @@ -62,20 +63,26 @@ def default(val, d): def always(val): + def inner(*args, **kwargs): return val + return inner def not_equals(val): + def inner(x): return x != val + return inner def equals(val): + def inner(x): return x == val + return inner @@ -85,6 +92,7 @@ def max_neg_value(tensor): # keyword argument helpers + def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) @@ -96,7 +104,7 @@ def group_dict_by_key(cond, d): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] - return (*return_val,) + return (*return_val, ) def string_begins_with(prefix, str): @@ -108,13 +116,17 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix):], x[1]), + tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs # classes class Scale(nn.Module): + def __init__(self, value, fn): super().__init__() self.value = value @@ -126,6 +138,7 @@ class Scale(nn.Module): class Rezero(nn.Module): + def __init__(self, fn): super().__init__() self.fn = fn @@ -137,9 +150,10 @@ class Rezero(nn.Module): class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -149,9 +163,10 @@ class ScaleNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) @@ -161,11 +176,13 @@ class RMSNorm(nn.Module): class Residual(nn.Module): + def forward(self, x, residual): return x + residual class GRUGating(nn.Module): + def __init__(self, dim): super().__init__() self.gru = nn.GRUCell(dim, dim) @@ -173,15 +190,16 @@ class GRUGating(nn.Module): def forward(self, x, residual): gated_output = self.gru( rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') - ) + rearrange(residual, 'b n d -> (b n) d')) return gated_output.reshape_as(x) # feedforward + class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) @@ -192,20 +210,16 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - 
self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -213,24 +227,24 @@ class FeedForward(nn.Module): # attention. class Attention(nn.Module): - def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False - ): + + def __init__(self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False): super().__init__() if use_entmax15: - raise NotImplementedError("Check out entmax activation instead of softmax activation!") - self.scale = dim_head ** -0.5 + raise NotImplementedError( + 'Check out entmax activation instead of softmax activation!') + self.scale = dim_head**-0.5 self.heads = heads self.causal = causal self.mask = mask @@ -252,7 +266,7 @@ class Attention(nn.Module): self.sparse_topk = sparse_topk # entmax - #self.attn_fn = entmax15 if use_entmax15 else F.softmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax self.attn_fn = F.softmax # add memory key / values @@ -263,19 +277,19 @@ class Attention(nn.Module): # attention on attention self.attn_on_attn = on_attn - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + self.to_out = nn.Sequential(nn.Linear( + inner_dim, dim + * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None - ): + def forward(self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None): b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device kv_input = default(context, x) @@ -297,23 +311,29 @@ class Attention(nn.Module): k = self.to_k(k_input) v = self.to_v(v_input) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), + (q, k, v)) input_mask = None if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + q_mask = default(mask, lambda: torch.ones( + (b, n), device=device).bool()) k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + k_mask = default( + k_mask, lambda: torch.ones( + (b, k.shape[-2]), device=device).bool()) q_mask = rearrange(q_mask, 'b i -> b () i ()') k_mask = rearrange(k_mask, 'b j -> b () () j') input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), + (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + input_mask = F.pad( + input_mask, (self.num_mem_kv, 0), value=True) dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale mask_value = max_neg_value(dots) @@ -324,7 +344,8 @@ class Attention(nn.Module): pre_softmax_attn = dots if talking_heads: 
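            # 'talking heads': mix the attention logits across the head dimension with a learned projection before the softmax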
- dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + dots = einsum('b h i j, h k -> b k i j', dots, + self.pre_softmax_proj).contiguous() if exists(rel_pos): dots = rel_pos(dots) @@ -336,7 +357,8 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = rearrange(r, 'i -> () () i ()') < rearrange( + r, 'j -> () () () j') mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask @@ -354,59 +376,61 @@ class Attention(nn.Module): attn = self.dropout(attn) if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + attn = einsum('b h i j, h k -> b k i j', attn, + self.post_softmax_proj).contiguous() out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') intermediates = Intermediates( pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn - ) + post_softmax_attn=post_softmax_attn) return self.to_out(out), intermediates class AttentionLayers(nn.Module): - def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs - ): + + def __init__(self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs): super().__init__() ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + # dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth self.layers = nn.ModuleList([]) self.has_pos_emb = position_infused_attn - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.pia_pos_emb = FixedPositionalEmbedding( + dim) if position_infused_attn else None self.rotary_pos_emb = always(None) - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than \ + the relative position max distance' + self.rel_pos = None self.pre_norm = pre_norm @@ -429,7 +453,7 @@ class AttentionLayers(nn.Module): default_block = ('a', 'f') if macaron: - default_block = ('f',) + default_block + default_block = ('f', ) + default_block if exists(custom_layers): layer_types = custom_layers @@ -440,13 +464,17 @@ class AttentionLayers(nn.Module): par_attn = par_depth // par_ratio depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * 
(par_width - len(default_block)) + assert len( + default_block + ) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f', ) * ( + par_width - len(default_block)) par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - len(par_head)) + layer_types = par_head + ('f', ) * (par_depth - len(par_head)) elif exists(sandwich_coef): assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + layer_types = ('a', ) * sandwich_coef + default_block * ( + depth - sandwich_coef) + ('f', ) * sandwich_coef else: layer_types = default_block * depth @@ -455,7 +483,8 @@ class AttentionLayers(nn.Module): for layer_type in self.layer_types: if layer_type == 'a': - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + layer = Attention( + dim, heads=heads, causal=causal, **attn_kwargs) elif layer_type == 'c': layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': @@ -472,21 +501,15 @@ class AttentionLayers(nn.Module): else: residual_fn = Residual() - self.layers.append(nn.ModuleList([ - norm_fn(), - layer, - residual_fn - ])) + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False - ): + def forward(self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False): hiddens = [] intermediates = [] prev_attn = None @@ -494,7 +517,8 @@ class AttentionLayers(nn.Module): mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers - for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + for ind, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers)): is_last = ind == (len(self.layers) - 1) if layer_type == 'a': @@ -507,10 +531,20 @@ class AttentionLayers(nn.Module): x = norm(x) if layer_type == 'a': - out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, - prev_attn=prev_attn, mem=layer_mem) + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem) elif layer_type == 'c': - out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn) elif layer_type == 'f': out = block(x) @@ -529,9 +563,7 @@ class AttentionLayers(nn.Module): if return_hiddens: intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates - ) + hiddens=hiddens, attn_intermediates=intermediates) return x, intermediates @@ -539,28 +571,29 @@ class AttentionLayers(nn.Module): class Encoder(AttentionLayers): + def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on encoder' super().__init__(causal=False, **kwargs) - class TransformerWrapper(nn.Module): - def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True - ): + + def __init__(self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + 
use_pos_emb=True): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance( + attn_layers, AttentionLayers + ), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -571,22 +604,26 @@ class TransformerWrapper(nn.Module): self.token_emb = nn.Embedding(num_tokens, emb_dim) self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.project_emb = nn.Linear(emb_dim, + dim) if emb_dim != dim else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() - self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + self.to_logits = nn.Linear( + dim, num_tokens + ) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + self.memory_tokens = nn.Parameter( + torch.randn(num_memory_tokens, dim)) # let funnel encoder know number of memory tokens, if specified if hasattr(attn_layers, 'num_memory_tokens'): @@ -595,17 +632,15 @@ class TransformerWrapper(nn.Module): def init_(self): nn.init.normal_(self.token_emb.weight, std=0.02) - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - **kwargs - ): - b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + def forward(self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs): + b, _, _, num_mem = *x.shape, x.device, self.num_memory_tokens x = self.token_emb(x) x += self.pos_emb(x) x = self.emb_dropout(x) @@ -620,7 +655,8 @@ class TransformerWrapper(nn.Module): if exists(mask): mask = F.pad(mask, (num_mem, 0), value=True) - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] @@ -629,13 +665,18 @@ class TransformerWrapper(nn.Module): if return_mems: hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + new_mems = list( + map(lambda pair: torch.cat(pair, dim=-2), zip( + mems, hiddens))) if exists(mems) else hiddens + new_mems = list( + map(lambda t: t[..., -self.max_mem_len:, :].detach(), + new_mems)) return out, new_mems if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = list( + map(lambda t: t.post_softmax_attn, + intermediates.attn_intermediates)) return out, attn_maps return out - diff --git a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py index 983baaa5..954db9cd 100644 --- a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py +++ 
b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py @@ -1,121 +1,133 @@ # https://github.com/eladrich/pixel2style2pixel from collections import namedtuple -import torch -from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module -""" -ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) -""" +import torch +from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, + Module, PReLU, ReLU, Sequential, Sigmoid) + +# ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) class Flatten(Module): - def forward(self, input): - return input.view(input.size(0), -1) + + def forward(self, input): + return input.view(input.size(0), -1) def l2_norm(input, axis=1): - norm = torch.norm(input, 2, axis, True) - output = torch.div(input, norm) - return output + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): - """ A named tuple describing a ResNet block. """ + """ A named tuple describing a ResNet block. """ def get_block(in_channel, depth, num_units, stride=2): - return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] def get_blocks(num_layers): - if num_layers == 50: - blocks = [ - get_block(in_channel=64, depth=64, num_units=3), - get_block(in_channel=64, depth=128, num_units=4), - get_block(in_channel=128, depth=256, num_units=14), - get_block(in_channel=256, depth=512, num_units=3) - ] - elif num_layers == 100: - blocks = [ - get_block(in_channel=64, depth=64, num_units=3), - get_block(in_channel=64, depth=128, num_units=13), - get_block(in_channel=128, depth=256, num_units=30), - get_block(in_channel=256, depth=512, num_units=3) - ] - elif num_layers == 152: - blocks = [ - get_block(in_channel=64, depth=64, num_units=3), - get_block(in_channel=64, depth=128, num_units=8), - get_block(in_channel=128, depth=256, num_units=36), - get_block(in_channel=256, depth=512, num_units=3) - ] - else: - raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) - return blocks + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError( + 'Invalid number of layers: {}. Must be one of [50, 100, 152]'. 
+ format(num_layers)) + return blocks class SEModule(Module): - def __init__(self, channels, reduction): - super(SEModule, self).__init__() - self.avg_pool = AdaptiveAvgPool2d(1) - self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) - self.relu = ReLU(inplace=True) - self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) - self.sigmoid = Sigmoid() - def forward(self, x): - module_input = x - x = self.avg_pool(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.sigmoid(x) - return module_input * x + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x class bottleneck_IR(Module): - def __init__(self, in_channel, depth, stride): - super(bottleneck_IR, self).__init__() - if in_channel == depth: - self.shortcut_layer = MaxPool2d(1, stride) - else: - self.shortcut_layer = Sequential( - Conv2d(in_channel, depth, (1, 1), stride, bias=False), - BatchNorm2d(depth) - ) - self.res_layer = Sequential( - BatchNorm2d(in_channel), - Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), - Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) - ) - def forward(self, x): - shortcut = self.shortcut_layer(x) - res = self.res_layer(x) - return res + shortcut + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut class bottleneck_IR_SE(Module): - def __init__(self, in_channel, depth, stride): - super(bottleneck_IR_SE, self).__init__() - if in_channel == depth: - self.shortcut_layer = MaxPool2d(1, stride) - else: - self.shortcut_layer = Sequential( - Conv2d(in_channel, depth, (1, 1), stride, bias=False), - BatchNorm2d(depth) - ) - self.res_layer = Sequential( - BatchNorm2d(in_channel), - Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), - PReLU(depth), - Conv2d(depth, depth, (3, 3), stride, 1, bias=False), - BatchNorm2d(depth), - SEModule(depth, 16) - ) - def forward(self, x): - shortcut = self.shortcut_layer(x) - res = self.res_layer(x) - return res + shortcut \ No newline at end of file + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + 
BatchNorm2d(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py index 16dc0dc7..c6cb52bc 100644 --- a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py +++ b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py @@ -1,22 +1,26 @@ # https://github.com/eladrich/pixel2style2pixel import torch from torch import nn + from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone class IDFeatures(nn.Module): + def __init__(self, model_path): super(IDFeatures, self).__init__() print('Loading ResNet ArcFace') - self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') - self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) + self.facenet = Backbone( + input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict( + torch.load(model_path, map_location='cpu')) self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) self.facenet.eval() def forward(self, x, crop=False): # Not sure of the image range here if crop: - x = torch.nn.functional.interpolate(x, (256, 256), mode="area") + x = torch.nn.functional.interpolate(x, (256, 256), mode='area') x = x[:, :, 35:223, 32:220] x = self.face_pool(x) x_feats = self.facenet(x) diff --git a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/model_irse.py b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/model_irse.py index 6fe5f241..f3d6deab 100644 --- a/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/model_irse.py +++ b/modelscope/models/cv/image_to_3d/ldm/thirdp/psp/model_irse.py @@ -1,86 +1,96 @@ # https://github.com/eladrich/pixel2style2pixel -from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module -from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + Module, PReLU, Sequential) -""" -Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) -""" +from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import ( + Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, l2_norm) + +# Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) class Backbone(Module): - def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): - super(Backbone, self).__init__() - assert input_size in [112, 224], "input_size should be 112 or 224" - assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" - assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" - blocks = get_blocks(num_layers) - if mode == 'ir': - unit_module = bottleneck_IR - elif mode == 'ir_se': - unit_module = bottleneck_IR_SE - self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), - BatchNorm2d(64), - PReLU(64)) - if input_size == 112: - self.output_layer = Sequential(BatchNorm2d(512), - Dropout(drop_ratio), - Flatten(), - Linear(512 * 7 * 7, 512), - BatchNorm1d(512, affine=affine)) - else: - self.output_layer = Sequential(BatchNorm2d(512), - Dropout(drop_ratio), - Flatten(), - Linear(512 * 14 * 14, 512), - BatchNorm1d(512, affine=affine)) - modules = [] - for block in blocks: - for bottleneck in block: - modules.append(unit_module(bottleneck.in_channel, - 
bottleneck.depth, - bottleneck.stride)) - self.body = Sequential(*modules) + def __init__(self, + input_size, + num_layers, + mode='ir', + drop_ratio=0.4, + affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], 'input_size should be 112 or 224' + assert num_layers in [50, 100, + 152], 'num_layers should be 50, 100 or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) - def forward(self, x): - x = self.input_layer(x) - x = self.body(x) - x = self.output_layer(x) - return l2_norm(x) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) def IR_50(input_size): - """Constructs a ir-50 model.""" - model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) - return model + """Constructs a ir-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model def IR_101(input_size): - """Constructs a ir-101 model.""" - model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) - return model + """Constructs a ir-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model def IR_152(input_size): - """Constructs a ir-152 model.""" - model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) - return model + """Constructs a ir-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model def IR_SE_50(input_size): - """Constructs a ir_se-50 model.""" - model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) - return model + """Constructs a ir_se-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model def IR_SE_101(input_size): - """Constructs a ir_se-101 model.""" - model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) - return model + """Constructs a ir_se-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model def IR_SE_152(input_size): - """Constructs a ir_se-152 model.""" - model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) - return model \ No newline at end of file + """Constructs a ir_se-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/modelscope/models/cv/image_to_3d/ldm/util.py b/modelscope/models/cv/image_to_3d/ldm/util.py index d27bfee5..83ac20a3 100644 --- a/modelscope/models/cv/image_to_3d/ldm/util.py +++ b/modelscope/models/cv/image_to_3d/ldm/util.py @@ -1,32 +1,28 
@@ import importlib - -import torchvision -import torch -from torch import optim -import numpy as np - -from inspect import isfunction -from PIL import Image, ImageDraw, ImageFont - import os -import numpy as np -import matplotlib.pyplot as plt -from PIL import Image -import torch import time +from inspect import isfunction + import cv2 +import matplotlib.pyplot as plt +import numpy as np import PIL +import torch +import torchvision +from PIL import Image, ImageDraw, ImageFont +from torch import optim + def pil_rectangle_crop(im): - width, height = im.size # Get dimensions - + width, height = im.size # Get dimensions + if width <= height: left = 0 right = width - top = (height - width)/2 - bottom = (height + width)/2 + top = (height - width) / 2 + bottom = (height + width) / 2 else: - + top = 0 bottom = height left = (width - height) / 2 @@ -36,6 +32,7 @@ def pil_rectangle_crop(im): im = im.crop((left, top, right, bottom)) return im + def add_margin(pil_img, color=0, size=256): width, height = pil_img.size result = Image.new(pil_img.mode, (size, size), color) @@ -46,16 +43,17 @@ def add_margin(pil_img, color=0, size=256): def create_carvekit_interface(): from carvekit.api.high import HiInterface # Check doc strings for more information - interface = HiInterface(object_type="object", # Can be "object" or "hairs-like". - batch_size_seg=5, - batch_size_matting=1, - device='cuda' if torch.cuda.is_available() else 'cpu', - seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net - matting_mask_size=2048, - trimap_prob_threshold=231, - trimap_dilation=30, - trimap_erosion_iters=5, - fp16=False) + interface = HiInterface( + object_type='object', # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device='cuda' if torch.cuda.is_available() else 'cpu', + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=False) return interface @@ -72,17 +70,17 @@ def load_and_preprocess(interface, input_im): image_without_background = np.array(image_without_background) est_seg = image_without_background > 127 image = np.array(image) - foreground = est_seg[:, : , -1].astype(np.bool_) + foreground = est_seg[:, :, -1].astype(np.bool_) image[~foreground] = [255., 255., 255.] x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) - image = image[y:y+h, x:x+w, :] + image = image[y:y + h, x:x + w, :] image = PIL.Image.fromarray(np.array(image)) - + # resize image such that long edge is 512 image.thumbnail([200, 200], Image.LANCZOS) image = add_margin(image, (255, 255, 255), size=256) image = np.array(image) - + return image @@ -92,16 +90,17 @@ def log_txt_as_img(wh, xc, size=10): b = len(xc) txts = list() for bi in range(b): - txt = Image.new("RGB", wh, color="white") + txt = Image.new('RGB', wh, color='white') draw = ImageDraw.Draw(txt) font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = '\n'.join(xc[bi][start:start + nc] + for start in range(0, len(xc[bi]), nc)) try: - draw.text((0, 0), lines, fill="black", font=font) + draw.text((0, 0), lines, fill='black', font=font) except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") + print('Cant encode string for logging. 
Skipping.') txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -117,7 +116,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -143,22 +142,24 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + print( + f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.' + ) return total_params def instantiate_from_config(config): - if not "target" in config: + if 'target' not in config: if config == '__is_first_stage__': return None - elif config == "__is_unconditional__": + elif config == '__is_unconditional__': return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) + raise KeyError('Expected key `target` to instantiate.') + return get_obj_from_str(config['target'])(**config.get('params', dict())) def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) + module, cls = string.rsplit('.', 1) print(module) if reload: module_imp = importlib.import_module(module) @@ -168,25 +169,43 @@ def get_obj_from_str(string, reload=False): class AdamWwithEMAandWings(optim.Optimizer): # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): + def __init__( + self, # noqa + params, # noqa + lr=1.e-3, # noqa + betas=(0.9, 0.999), # noqa + eps=1.e-8, # noqa + weight_decay=1.e-2, # noqa + amsgrad=False, # noqa + ema_decay=0.9999, # ema decay to match previous code # noqa + ema_power=1., # noqa + param_names=()): # noqa + # TODO: check hyperparameters before using """AdamW that saves EMA versions of the parameters.""" if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError('Invalid learning rate: {}'.format(lr)) if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError('Invalid epsilon value: {}'.format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + 'Invalid weight_decay value: {}'.format(weight_decay)) if not 0.0 <= ema_decay <= 1.0: - raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, param_names=param_names) + raise ValueError('Invalid ema_decay value: {}'.format(ema_decay)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names) super().__init__(params, defaults) def __setstate__(self, state): @@ -212,7 +231,7 @@ class 
AdamWwithEMAandWings(optim.Optimizer): exp_avgs = [] exp_avg_sqs = [] ema_params_with_grad = [] - state_sums = [] + # state_sums = [] max_exp_avg_sqs = [] state_steps = [] amsgrad = group['amsgrad'] @@ -225,7 +244,8 @@ class AdamWwithEMAandWings(optim.Optimizer): continue params_with_grad.append(p) if p.grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') + raise RuntimeError( + 'AdamW does not support sparse gradients') grads.append(p.grad) state = self.state[p] @@ -234,12 +254,15 @@ class AdamWwithEMAandWings(optim.Optimizer): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) # Exponential moving average of parameter values state['param_exp_avg'] = p.detach().float().clone() @@ -255,22 +278,25 @@ class AdamWwithEMAandWings(optim.Optimizer): # record the step after step update state_steps.append(state['step']) - optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) + optim._functional.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) - for param, ema_param in zip(params_with_grad, ema_params_with_grad): - ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power) + for param, ema_param in zip(params_with_grad, + ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_( + param.float(), alpha=1 - cur_ema_decay) - return loss \ No newline at end of file + return loss diff --git a/modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py b/modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py index 957f9653..e904aad2 100644 --- a/modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py +++ b/modelscope/models/cv/video_frame_interpolation/rife/IFNet_HDv3.py @@ -5,85 +5,117 @@ import torch import torch.nn as nn import torch.nn.functional as F + from .warplayer import warp -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): - return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, bias=True), - nn.PReLU(out_planes) - ) -def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): +def conv(in_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + dilation=1): return nn.Sequential( - 
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, bias=False), - nn.BatchNorm2d(out_planes), - nn.PReLU(out_planes) - ) + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True), nn.PReLU(out_planes)) + + +def conv_bn(in_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False), nn.BatchNorm2d(out_planes), nn.PReLU(out_planes)) + class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): super(IFBlock, self).__init__() self.conv0 = nn.Sequential( - conv(in_planes, c//2, 3, 2, 1), - conv(c//2, c, 3, 2, 1), - ) - self.convblock0 = nn.Sequential( - conv(c, c), - conv(c, c) - ) - self.convblock1 = nn.Sequential( - conv(c, c), - conv(c, c) - ) - self.convblock2 = nn.Sequential( - conv(c, c), - conv(c, c) - ) - self.convblock3 = nn.Sequential( - conv(c, c), - conv(c, c) + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), ) + self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) self.conv1 = nn.Sequential( - nn.ConvTranspose2d(c, c//2, 4, 2, 1), - nn.PReLU(c//2), - nn.ConvTranspose2d(c//2, 4, 4, 2, 1), + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 4, 4, 2, 1), ) self.conv2 = nn.Sequential( - nn.ConvTranspose2d(c, c//2, 4, 2, 1), - nn.PReLU(c//2), - nn.ConvTranspose2d(c//2, 1, 4, 2, 1), + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 1, 4, 2, 1), ) def forward(self, x, flow, scale=1): - x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) - flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale + x = F.interpolate( + x, + scale_factor=1. / scale, + mode='bilinear', + align_corners=False, + recompute_scale_factor=False) + flow = F.interpolate( + flow, + scale_factor=1. / scale, + mode='bilinear', + align_corners=False, + recompute_scale_factor=False) * 1. 
/ scale feat = self.conv0(torch.cat((x, flow), 1)) feat = self.convblock0(feat) + feat feat = self.convblock1(feat) + feat feat = self.convblock2(feat) + feat - feat = self.convblock3(feat) + feat + feat = self.convblock3(feat) + feat flow = self.conv1(feat) mask = self.conv2(feat) - flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale - mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + flow = F.interpolate( + flow, + scale_factor=scale, + mode='bilinear', + align_corners=False, + recompute_scale_factor=False) * scale + mask = F.interpolate( + mask, + scale_factor=scale, + mode='bilinear', + align_corners=False, + recompute_scale_factor=False) return flow, mask - + + class IFNet(nn.Module): + def __init__(self): super(IFNet, self).__init__() - self.block0 = IFBlock(7+4, c=90) - self.block1 = IFBlock(7+4, c=90) - self.block2 = IFBlock(7+4, c=90) - self.block_tea = IFBlock(10+4, c=90) + self.block0 = IFBlock(7 + 4, c=90) + self.block1 = IFBlock(7 + 4, c=90) + self.block2 = IFBlock(7 + 4, c=90) + self.block_tea = IFBlock(10 + 4, c=90) # self.contextnet = Contextnet() # self.unet = Unet() def forward(self, x, scale_list=[4, 2, 1], training=False): - if training == False: + if training is False: channel = x.shape[1] // 2 img0 = x[:, :channel] img1 = x[:, channel:] @@ -94,11 +126,17 @@ class IFNet(nn.Module): warped_img1 = img1 flow = (x[:, :4]).detach() * 0 mask = (x[:, :1]).detach() * 0 - loss_cons = 0 + # loss_cons = 0 block = [self.block0, self.block1, self.block2] for i in range(3): - f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) - f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + f0, m0 = block[i]( + torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), + flow, + scale=scale_list[i]) + f1, m1 = block[i]( + torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[i]) flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 mask = mask + (m0 + (-m1)) / 2 mask_list.append(mask) @@ -114,6 +152,7 @@ class IFNet(nn.Module): ''' for i in range(3): mask_list[i] = torch.sigmoid(mask_list[i]) - merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) - # merged[i] = torch.clamp(merged[i] + res, 0, 1) + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * ( + 1 - mask_list[i]) + # merged[i] = torch.clamp(merged[i] + res, 0, 1) return flow_list, mask_list[2], merged diff --git a/modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py b/modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py index 359d573a..090b7cd7 100644 --- a/modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py +++ b/modelscope/models/cv/video_frame_interpolation/rife/RIFE_HDv3.py @@ -2,17 +2,15 @@ # originally MIT License, Copyright (c) Megvii Inc., # and publicly available at https://github.com/megvii-research/ECCV2022-RIFE +import itertools + +import numpy as np import torch import torch.nn as nn -import numpy as np -from torch.optim import AdamW -import torch.optim as optim -import itertools -from .warplayer import warp -from torch.nn.parallel import DistributedDataParallel as DDP -from .IFNet_HDv3 import * import torch.nn.functional as F -from .loss import * +import torch.optim as optim +from 
torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW from modelscope.metainfo import Models from modelscope.models.base import Tensor @@ -21,15 +19,23 @@ from modelscope.models.builder import MODELS from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger +from .IFNet_HDv3 import * +from .loss import * +from .warplayer import warp -@MODELS.register_module(Tasks.video_frame_interpolation, module_name=Models.rife) + +@MODELS.register_module( + Tasks.video_frame_interpolation, module_name=Models.rife) class RIFEModel(TorchModel): + def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') self.flownet = IFNet() self.flownet.to(self.device) - self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.optimG = AdamW( + self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) self.epe = EPE() # self.vgg = VGGPerceptualLoss().to(device) self.sobel = SOBEL() @@ -43,62 +49,76 @@ class RIFEModel(TorchModel): self.flownet.eval() def load_model(self, path, rank=0): + def convert(param): if rank == -1: return { - k.replace("module.", ""): v - for k, v in param.items() - if "module." in k + k.replace('module.', ''): v + for k, v in param.items() if 'module.' in k } else: return param + if rank <= 0: if torch.cuda.is_available(): - self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path)))) + self.flownet.load_state_dict( + convert(torch.load('{}/flownet.pkl'.format(path)))) else: - self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu'))) - + self.flownet.load_state_dict( + convert( + torch.load( + '{}/flownet.pkl'.format(path), + map_location='cpu'))) + def save_model(self, path, rank=0): if rank == 0: - torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) + torch.save(self.flownet.state_dict(), + '{}/flownet.pkl'.format(path)) def inference(self, img0, img1, scale=1.0): imgs = torch.cat((img0, img1), 1) - scale_list = [4/scale, 2/scale, 1/scale] + scale_list = [4 / scale, 2 / scale, 1 / scale] _, _, merged = self.flownet(imgs, scale_list) return merged[2].detach() - + def forward(self, inputs): img0 = inputs['img0'] img1 = inputs['img1'] scale = inputs['scale'] return {'output': self.inference(img0, img1, scale)} - def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + def update(self, + imgs, + gt, + learning_rate=0, + mul=1, + training=True, + flow_gt=None): for param_group in self.optimG.param_groups: param_group['lr'] = learning_rate - img0 = imgs[:, :3] - img1 = imgs[:, 3:] + # img0 = imgs[:, :3] + # img1 = imgs[:, 3:] if training: self.train() else: self.eval() scale = [4, 2, 1] - flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + flow, mask, merged = self.flownet( + torch.cat((imgs, gt), 1), scale=scale, training=training) loss_l1 = (merged[2] - gt).abs().mean() - loss_smooth = self.sobel(flow[2], flow[2]*0).mean() + loss_smooth = self.sobel(flow[2], flow[2] * 0).mean() # loss_vgg = self.vgg(merged[2], gt) if training: self.optimG.zero_grad() loss_G = loss_cons + loss_smooth * 0.1 loss_G.backward() self.optimG.step() - else: - flow_teacher = flow[2] + # else: + # flow_teacher = flow[2] return merged[2], { 'mask': 
mask, 'flow': flow[2][:, :2], 'loss_l1': loss_l1, 'loss_cons': loss_cons, 'loss_smooth': loss_smooth, - } + } diff --git a/modelscope/models/cv/video_frame_interpolation/rife/__init__.py b/modelscope/models/cv/video_frame_interpolation/rife/__init__.py index a1d5b148..af475199 100644 --- a/modelscope/models/cv/video_frame_interpolation/rife/__init__.py +++ b/modelscope/models/cv/video_frame_interpolation/rife/__init__.py @@ -2,4 +2,4 @@ # originally MIT License, Copyright (c) Megvii Inc., # and publicly available at https://github.com/megvii-research/ECCV2022-RIFE -from .RIFE_HDv3 import RIFEModel \ No newline at end of file +from .RIFE_HDv3 import RIFEModel diff --git a/modelscope/models/cv/video_frame_interpolation/rife/loss.py b/modelscope/models/cv/video_frame_interpolation/rife/loss.py index 62f19baf..97f7644c 100644 --- a/modelscope/models/cv/video_frame_interpolation/rife/loss.py +++ b/modelscope/models/cv/video_frame_interpolation/rife/loss.py @@ -2,26 +2,28 @@ # originally MIT License, Copyright (c) Megvii Inc., # and publicly available at https://github.com/megvii-research/ECCV2022-RIFE -import torch import numpy as np +import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class EPE(nn.Module): + def __init__(self): super(EPE, self).__init__() def forward(self, flow, gt, loss_mask): - loss_map = (flow - gt.detach()) ** 2 - loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + loss_map = (flow - gt.detach())**2 + loss_map = (loss_map.sum(1, True) + 1e-6)**0.5 return (loss_map * loss_mask) class Ternary(nn.Module): + def __init__(self): super(Ternary, self).__init__() patch_size = 7 @@ -43,7 +45,7 @@ class Ternary(nn.Module): return gray def hamming(self, t1, t2): - dist = (t1 - t2) ** 2 + dist = (t1 - t2)**2 dist_norm = torch.mean(dist / (0.1 + dist), 1, True) return dist_norm @@ -60,6 +62,7 @@ class Ternary(nn.Module): class SOBEL(nn.Module): + def __init__(self): super(SOBEL, self).__init__() self.kernelX = torch.tensor([ @@ -74,17 +77,20 @@ class SOBEL(nn.Module): def forward(self, pred, gt): N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] img_stack = torch.cat( - [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0) + [pred.reshape(N * C, 1, H, W), + gt.reshape(N * C, 1, H, W)], 0) sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) - pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:] - pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:] + pred_X, gt_X = sobel_stack_x[:N * C], sobel_stack_x[N * C:] + pred_Y, gt_Y = sobel_stack_y[:N * C], sobel_stack_y[N * C:] - L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y) - loss = (L1X+L1Y) + L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y) + loss = (L1X + L1Y) return loss + class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): c = len(data_mean) super(MeanShift, self).__init__(c, c, kernel_size=1) @@ -98,14 +104,19 @@ class MeanShift(nn.Conv2d): self.weight.data.mul_(std.view(c, 1, 1, 1)) self.bias.data = data_range * torch.Tensor(data_mean) self.requires_grad = False - + + class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): super(VGGPerceptualLoss, self).__init__() - blocks = [] + # blocks = [] pretrained = True - self.vgg_pretrained_features = 
models.vgg19(pretrained=pretrained).features - self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + self.vgg_pretrained_features = models.vgg19( + pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225], + norm=True).cuda() for param in self.parameters(): param.requires_grad = False @@ -113,20 +124,21 @@ class VGGPerceptualLoss(torch.nn.Module): X = self.normalize(X) Y = self.normalize(Y) indices = [2, 7, 12, 21, 30] - weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5] k = 0 loss = 0 for i in range(indices[-1]): X = self.vgg_pretrained_features[i](X) Y = self.vgg_pretrained_features[i](Y) - if (i+1) in indices: + if (i + 1) in indices: loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 k += 1 return loss + if __name__ == '__main__': img0 = torch.zeros(3, 3, 256, 256).float().to(device) - img1 = torch.tensor(np.random.normal( - 0, 1, (3, 3, 256, 256))).float().to(device) + img1 = torch.tensor(np.random.normal(0, 1, + (3, 3, 256, 256))).float().to(device) ternary_loss = Ternary() print(ternary_loss(img0, img1).shape) diff --git a/modelscope/models/cv/video_frame_interpolation/rife/warplayer.py b/modelscope/models/cv/video_frame_interpolation/rife/warplayer.py index 9a3f8eff..e4440e6f 100644 --- a/modelscope/models/cv/video_frame_interpolation/rife/warplayer.py +++ b/modelscope/models/cv/video_frame_interpolation/rife/warplayer.py @@ -5,22 +5,36 @@ import torch import torch.nn as nn -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') backwarp_tenGrid = {} def warp(tenInput, tenFlow): k = (str(tenFlow.device), str(tenFlow.size())) if k not in backwarp_tenGrid: - tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( - 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) - tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( - 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) - backwarp_tenGrid[k] = torch.cat( - [tenHorizontal, tenVertical], 1).to(device) + tenHorizontal = torch.linspace( + -1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, + tenFlow.shape[2], -1) + tenVertical = torch.linspace( + -1.0, 1.0, tenFlow.shape[2], + device=device).view(1, 1, tenFlow.shape[2], + 1).expand(tenFlow.shape[0], -1, -1, + tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], + 1).to(device) - tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), - tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), # no qa + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) + ], + 1) # no qa g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) + return torch.nn.functional.grid_sample( + input=tenInput, + grid=g, + mode='bilinear', + padding_mode='border', + align_corners=True) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 30c5e484..b9bc1d17 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -293,7 +293,9 @@ else: ], 
'human3d_render_pipeline': ['Human3DRenderPipeline'], 'human3d_animation_pipeline': ['Human3DAnimationPipeline'], - 'rife_video_frame_interpolation_pipeline': ['RIFEVideoFrameInterpolationPipeline'], + 'rife_video_frame_interpolation_pipeline': [ + 'RIFEVideoFrameInterpolationPipeline' + ], 'anydoor_pipeline': ['AnydoorPipeline'], } diff --git a/modelscope/pipelines/cv/image_to_3d_pipeline.py b/modelscope/pipelines/cv/image_to_3d_pipeline.py index 3dcd2de3..d74003d6 100644 --- a/modelscope/pipelines/cv/image_to_3d_pipeline.py +++ b/modelscope/pipelines/cv/image_to_3d_pipeline.py @@ -1,28 +1,27 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp from typing import Any, Dict -import rembg + import cv2 import numpy as np import PIL +import rembg import torch import torch.nn.functional as F import torchvision.transforms as T import torchvision.transforms.functional as TF +from omegaconf import OmegaConf from PIL import Image from torchvision.utils import save_image -from omegaconf import OmegaConf + # import modelscope.models.cv.image_to_image_generation.data as data # import modelscope.models.cv.image_to_image_generation.models as models # import modelscope.models.cv.image_to_image_generation.ops as ops from modelscope.metainfo import Pipelines -# from modelscope.models.cv.image_to_3d.model import UNet -# from modelscope.models.cv.image_to_image_generation.models.clip import \ -# VisionTransformer - -from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion -from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config, add_margin - +# from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer import \ +# SyncMultiviewDiffusion +from modelscope.models.cv.image_to_3d.ldm.util import (add_margin, + instantiate_from_config) from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES @@ -31,23 +30,29 @@ from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger +# from modelscope.models.cv.image_to_3d.model import UNet +# from modelscope.models.cv.image_to_image_generation.models.clip import \ +# VisionTransformer + logger = get_logger() + # Load Syncdreamer Model def load_model(cfg, ckpt, strict=True): config = OmegaConf.load(cfg) model = instantiate_from_config(config.model) print(f'loading model from {ckpt} ...') - ckpt = torch.load(ckpt,map_location='cpu') - model.load_state_dict(ckpt['state_dict'],strict=strict) + ckpt = torch.load(ckpt, map_location='cpu') + model.load_state_dict(ckpt['state_dict'], strict=strict) model = model.cuda().eval() return model + # Prepare Syncdreamer Input def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256): - image_input[:,:,:3] = image_input[:,:,:3][:,:,::-1] + image_input[:, :, :3] = image_input[:, :, :3][:, :, ::-1] image_input = Image.fromarray(image_input) - if crop_size!=-1: + if crop_size != -1: alpha_np = np.asarray(image_input)[:, :, 3] coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] min_x, min_y = np.min(coords, 0) @@ -59,21 +64,26 @@ def prepare_inputs(image_input, elevation_input, crop_size=-1, image_size=256): ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC) image_input = add_margin(ref_img_, size=image_size) else: - image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) - image_input = 
image_input.resize((image_size, image_size), resample=Image.BICUBIC) + image_input = add_margin( + image_input, size=max(image_input.height, image_input.width)) + image_input = image_input.resize((image_size, image_size), + resample=Image.BICUBIC) image_input = np.asarray(image_input) image_input = image_input.astype(np.float32) / 255.0 ref_mask = image_input[:, :, 3:] - image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background + image_input[:, :, : + 3] = image_input[:, :, : + 3] * ref_mask + 1 - ref_mask # white background image_input = image_input[:, :, :3] * 2.0 - 1.0 image_input = torch.from_numpy(image_input.astype(np.float32)) - elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32)) - return {"input_image": image_input, "input_elevation": elevation_input} + elevation_input = torch.from_numpy( + np.asarray([np.deg2rad(elevation_input)], np.float32)) + return {'input_image': image_input, 'input_elevation': elevation_input} + @PIPELINES.register_module( - Tasks.image_to_3d, - module_name=Pipelines.image_to_3d) + Tasks.image_to_3d, module_name=Pipelines.image_to_3d) class Image23DPipeline(Pipeline): def __init__(self, model: str, **kwargs): @@ -91,23 +101,28 @@ class Image23DPipeline(Pipeline): self._device = torch.device('cuda') else: self._device = torch.device('cpu') - ckpt = config_path.replace("configuration.json", "syncdreamer-pretrain.ckpt") - self.model = load_model(config_path.replace("configuration.json", "syncdreamer.yaml"), ckpt).to(self._device) + ckpt = config_path.replace('configuration.json', + 'syncdreamer-pretrain.ckpt') + self.model = load_model( + config_path.replace('configuration.json', 'syncdreamer.yaml'), + ckpt).to(self._device) # os.system("pip install -r {}".format(config_path.replace("configuration.json", "requirements.txt"))) # assert isinstance(self.model, SyncMultiviewDiffusion) def preprocess(self, input: Input) -> Dict[str, Any]: - + result = rembg.remove(Image.open(input)) print(type(result)) img = np.array(result) - img[:,:,:3] = img[:,:,:3][:,:,::-1] + img[:, :, :3] = img[:, :, :3][:, :, ::-1] # img = cv2.imread(input) - data = prepare_inputs(img, elevation_input=10, crop_size=200, image_size=256) - - for k,v in data.items(): + data = prepare_inputs( + img, elevation_input=10, crop_size=200, image_size=256) + + for k, v in data.items(): data[k] = v.unsqueeze(0).cuda() - data[k] = torch.repeat_interleave(data[k], 1, dim=0) # only one sample + data[k] = torch.repeat_interleave( + data[k], 1, dim=0) # only one sample return data @torch.no_grad() @@ -115,11 +130,11 @@ class Image23DPipeline(Pipeline): x_sample = self.model.sample(input, 2.0, 8) B, N, _, H, W = x_sample.shape - x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5 - x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255 + x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5 + x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255 x_sample = x_sample.astype(np.uint8) - show_in_im2 = [Image.fromarray(x_sample[0,ni]) for ni in range(N)] - return {'MViews':show_in_im2} + show_in_im2 = [Image.fromarray(x_sample[0, ni]) for ni in range(N)] + return {'MViews': show_in_im2} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/rife_video_frame_interpolation_pipeline.py b/modelscope/pipelines/cv/rife_video_frame_interpolation_pipeline.py index 1f50fee8..a4892273 100644 --- a/modelscope/pipelines/cv/rife_video_frame_interpolation_pipeline.py 
+++ b/modelscope/pipelines/cv/rife_video_frame_interpolation_pipeline.py @@ -46,6 +46,7 @@ class RIFEVideoFrameInterpolationPipeline(Pipeline): >>> print('pipeline: the output video path is {}'.format(result)) """ + def __init__(self, model: Union[RIFEModel, str], preprocessor=None, @@ -75,7 +76,7 @@ class RIFEVideoFrameInterpolationPipeline(Pipeline): def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: inputs = input['video'] - fps = input['fps'] + # fps = input['fps'] out_fps = input['out_fps'] video_len = len(inputs) diff --git a/tests/pipelines/test_image_to_3d.py b/tests/pipelines/test_image_to_3d.py index d4de345c..d909f71e 100644 --- a/tests/pipelines/test_image_to_3d.py +++ b/tests/pipelines/test_image_to_3d.py @@ -3,6 +3,7 @@ import unittest import numpy as np from PIL import Image + from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.pipelines.base import Pipeline @@ -24,11 +25,11 @@ class ImageTo3DTest(unittest.TestCase): def pipeline_inference(self, pipeline: Pipeline, input: str): result = pipeline(input['input_path']) np_content = [] - for idx,img in enumerate(result['MViews']): + for idx, img in enumerate(result['MViews']): np_content.append(np.array(result['MViews'][idx])) np_content = np.concatenate(np_content, axis=1) - Image.fromarray(np_content).save("./concat.png") + Image.fromarray(np_content).save('./concat.png') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub(self): @@ -38,4 +39,4 @@ class ImageTo3DTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/pipelines/test_rife_video_frame_interpolation.py b/tests/pipelines/test_rife_video_frame_interpolation.py index 5ff28451..78949e44 100644 --- a/tests/pipelines/test_rife_video_frame_interpolation.py +++ b/tests/pipelines/test_rife_video_frame_interpolation.py @@ -5,7 +5,6 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.outputs import OutputKeys -from modelscope.pipelines import pipeline from modelscope.pipelines.cv import RIFEVideoFrameInterpolationPipeline from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level
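Note: below is a minimal, hypothetical usage sketch of the RIFEModel inference path touched by the diff above. It is not part of the diff; the checkpoint directory path and the 256x256 frame size are illustrative assumptions (load_model only looks for a flownet.pkl inside the given directory, and the frame size is assumed compatible with the network's strided convolutions), and only methods visible in the diff (load_model, eval, inference, forward) are used.

# Hypothetical sketch only: paths and tensor sizes are illustrative assumptions.
import torch

from modelscope.models.cv.video_frame_interpolation.rife import RIFEModel

model_dir = '/path/to/rife-model'   # assumed directory containing flownet.pkl
model = RIFEModel(model_dir)
model.load_model(model_dir)         # loads '<model_dir>/flownet.pkl' per load_model()
model.eval()

# Two consecutive frames as float tensors in [0, 1], shape (N, 3, H, W),
# placed on the device the model selected in __init__.
img0 = torch.rand(1, 3, 256, 256, device=model.device)
img1 = torch.rand(1, 3, 256, 256, device=model.device)

# inference() concatenates the frames along channels, runs IFNet with
# scale_list = [4/scale, 2/scale, 1/scale] and returns the interpolated
# middle frame; forward() wraps the same call behind a dict interface.
mid_frame = model.inference(img0, img1, scale=1.0)
out = model.forward({'img0': img0, 'img1': img1, 'scale': 1.0})['output']
print(mid_frame.shape, out.shape)   # both (1, 3, 256, 256)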