Add xformers memory-efficient attention to unet_i2v.py

This commit is contained in:
kangzhao2
2023-08-18 19:41:35 +08:00
parent 5f0f63eb37
commit 8ff95a5dc8

View File

@@ -6,6 +6,8 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import rearrange
from fairscale.nn.checkpoint import checkpoint_wrapper
from rotary_embedding_torch import RotaryEmbedding
@@ -51,6 +53,57 @@ def prob_mask_like(shape, prob, device):
return mask
class MemoryEfficientCrossAttention(nn.Module):
    """Multi-head attention evaluated with xformers' memory-efficient kernel.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (``forward`` falls back to ``x`` through the file's
    ``default`` helper when no context is supplied). The projected tensors
    are split into heads, folded into the batch axis, passed to
    ``xformers.ops.memory_efficient_attention`` and finally merged back and
    sent through the output projection (linear layer followed by dropout).

    Args:
        query_dim: Size of the last axis of ``x``; also the output size.
        context_dim: Size of the last axis of ``context``. Resolved with
            ``default(context_dim, query_dim)``.
        heads: Number of attention heads.
        dim_head: Channels per head; the inner width is ``heads * dim_head``.
        dropout: Probability used by the ``nn.Dropout`` in ``to_out``. It is
            not forwarded to the attention kernel itself.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        # Bias-free input projections for queries, keys and values.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # Passed as ``op=`` to the xformers call; ``None`` leaves the choice
        # of kernel to xformers.
        self.attention_op: Optional[Any] = None

    def _split_heads(self, t, b):
        """Reshape ``(b, n, heads * dim_head)`` to ``(b * heads, n, dim_head)``."""
        n = t.shape[1]
        per_head = t.reshape(b, n, self.heads, self.dim_head)
        per_head = per_head.permute(0, 2, 1, 3)
        return per_head.reshape(b * self.heads, n,
                                self.dim_head).contiguous()

    def _merge_heads(self, t, b):
        """Reshape ``(b * heads, n, dim_head)`` back to ``(b, n, heads * dim_head)``."""
        n = t.shape[1]
        per_head = t.reshape(b, self.heads, n, self.dim_head)
        per_head = per_head.permute(0, 2, 1, 3)
        return per_head.reshape(b, n, self.heads * self.dim_head)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` and return the projected result.

        Args:
            x: Rank-3 tensor whose last axis has size ``query_dim``.
            context: Tensor providing keys/values; ``x`` is used when the
                ``default`` helper resolves it that way (self-attention).
            mask: Not supported. Must be ``None``.

        Returns:
            Tensor with the same leading two axes as ``x`` and a last axis
            of size ``query_dim``.

        Raises:
            NotImplementedError: If ``mask`` is provided.
        """
        # Reject masks before doing any work: the kernel is always invoked
        # with ``attn_bias=None``, so a mask could never be honoured, and
        # there is no point paying for the attention pass first.
        # NOTE(review): assumes the file's ``exists`` helper is a plain
        # not-None check - confirm.
        if mask is not None:
            raise NotImplementedError(
                'MemoryEfficientCrossAttention does not support `mask`.')

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Unpacking three values also enforces that the input is rank 3.
        b, _, _ = q.shape
        q, k, v = (self._split_heads(t, b) for t in (q, k, v))

        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op)

        return self.to_out(self._merge_heads(out, b))
class RelativePositionBias(nn.Module):
def __init__(self, heads=8, num_buckets=32, max_distance=128):
@@ -242,7 +295,7 @@ class BasicTransformerBlock(nn.Module):
disable_self_attn=False):
super().__init__()
attn_cls = CrossAttention
attn_cls = MemoryEfficientCrossAttention
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
@@ -1401,6 +1454,10 @@ class Img2VidSDUNet(nn.Module):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module