diff --git a/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py b/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
old mode 100755
new mode 100644
index eb258ccc..dae226a5
--- a/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
+++ b/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
@@ -6,6 +6,9 @@ import os
+from typing import Any, Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import xformers
+import xformers.ops
 from einops import rearrange
 from fairscale.nn.checkpoint import checkpoint_wrapper
 from rotary_embedding_torch import RotaryEmbedding
@@ -51,6 +54,59 @@ def prob_mask_like(shape, prob, device):
     return mask
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+
+    def __init__(self,
+                 query_dim,
+                 context_dim=None,
+                 heads=8,
+                 dim_head=64,
+                 dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        # fold heads into the batch axis: (b, n, h * d) -> (b * h, n, d)
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(b, t.shape[
+                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
+            (q, k, v),
+        )
+
+        out = xformers.ops.memory_efficient_attention(
+            q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        # unfold heads: (b * h, n, d) -> (b, n, h * d)
+        out = (
+            out.unsqueeze(0).reshape(
+                b, self.heads, out.shape[1],
+                self.dim_head).permute(0, 2, 1,
+                                       3).reshape(b, out.shape[1],
+                                                  self.heads * self.dim_head))
+        return self.to_out(out)
+
+
 class RelativePositionBias(nn.Module):
 
     def __init__(self, heads=8, num_buckets=32, max_distance=128):
@@ -242,7 +298,7 @@ class BasicTransformerBlock(nn.Module):
                  disable_self_attn=False):
         super().__init__()
 
-        attn_cls = CrossAttention
+        attn_cls = MemoryEfficientCrossAttention
         self.disable_self_attn = disable_self_attn
         self.attn1 = attn_cls(
             query_dim=dim,
@@ -1401,6 +1457,10 @@ class Img2VidSDUNet(nn.Module):
             module = checkpoint_wrapper(
                 module) if self.use_checkpoint else module
             x = module(x, context)
+        elif isinstance(module, MemoryEfficientCrossAttention):
+            module = checkpoint_wrapper(
+                module) if self.use_checkpoint else module
+            x = module(x, context)
         elif isinstance(module, BasicTransformerBlock):
             module = checkpoint_wrapper(
                 module) if self.use_checkpoint else module
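
As a quick sanity check for reviewers, below is a minimal sketch (not part of the patch) that exercises the same xformers call the new class relies on, using the folded (b * heads, seq_len, dim_head) layout that MemoryEfficientCrossAttention.forward builds, and compares it against a naive softmax-attention reference. It assumes xformers is installed and a CUDA device is available; the shapes, dtypes, and tolerances are illustrative choices, not values from the patch.

import torch
import xformers.ops

b, heads, n, dim_head = 2, 8, 16, 64

# Same folded layout as the patch: heads merged into the batch dimension.
q = torch.randn(b * heads, n, dim_head, device='cuda', dtype=torch.half)
k = torch.randn(b * heads, n, dim_head, device='cuda', dtype=torch.half)
v = torch.randn(b * heads, n, dim_head, device='cuda', dtype=torch.half)

# Memory-efficient path, called as in MemoryEfficientCrossAttention.forward;
# the default scale is 1 / sqrt(dim_head), matching the reference below.
out_xf = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

# Naive reference: softmax(q @ k^T / sqrt(dim_head)) @ v, computed in fp32.
scale = dim_head**-0.5
attn = torch.softmax(q.float() @ k.float().transpose(-2, -1) * scale, dim=-1)
out_ref = (attn @ v.float()).half()

# fp16 kernels drift slightly from the fp32 reference, hence loose tolerances.
assert torch.allclose(out_xf, out_ref, atol=1e-2, rtol=1e-2)
print('xformers output matches the naive attention reference')

If this check passes, swapping attn_cls from CrossAttention to MemoryEfficientCrossAttention should be numerically equivalent up to kernel precision, while avoiding materializing the full attention matrix at the long sequence lengths a video UNet produces.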