add xformer in unet_i2v.py
modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py (59 changed lines, Executable file → Normal file)
@@ -6,6 +6,8 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import xformers
+import xformers.ops
 from einops import rearrange
 from fairscale.nn.checkpoint import checkpoint_wrapper
 from rotary_embedding_torch import RotaryEmbedding
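
The two new imports make xformers a hard dependency of unet_i2v.py. A minimal sketch, not part of this commit (the flag name is hypothetical), of how the import could be guarded for environments without xformers:

# Hedged sketch, not in the commit: guard the new hard dependency so the
# module still imports on machines without xformers installed.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILABLE = True   # hypothetical flag, named here for illustration
except ImportError:
    XFORMERS_IS_AVAILABLE = False  # would allow falling back to plain CrossAttention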
@@ -51,6 +53,57 @@ def prob_mask_like(shape, prob, device):
         return mask


+class MemoryEfficientCrossAttention(nn.Module):
+
+    def __init__(self,
+                 query_dim,
+                 context_dim=None,
+                 heads=8,
+                 dim_head=64,
+                 dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(b, t.shape[
+                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
+            (q, k, v),
+        )
+
+        out = xformers.ops.memory_efficient_attention(
+            q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0).reshape(
+                b, self.heads, out.shape[1],
+                self.dim_head).permute(0, 2, 1,
+                                       3).reshape(b, out.shape[1],
+                                                  self.heads * self.dim_head))
+        return self.to_out(out)
+
+
 class RelativePositionBias(nn.Module):

     def __init__(self, heads=8, num_buckets=32, max_distance=128):
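
A minimal usage sketch for the new class, assuming it runs inside unet_i2v.py (where torch and the default/exists helpers are already in scope) and on a CUDA device, since xformers' memory_efficient_attention targets the GPU; all dimensions below are illustrative, not taken from the commit:

# Hypothetical shapes: 320-dim queries over 4096 spatial tokens attending
# to 77 CLIP-style context tokens; inner_dim = heads * dim_head = 320.
attn = MemoryEfficientCrossAttention(
    query_dim=320, context_dim=768, heads=8, dim_head=40).cuda()

x = torch.randn(2, 4096, 320, device='cuda')      # (batch, queries, query_dim)
context = torch.randn(2, 77, 768, device='cuda')  # (batch, tokens, context_dim)
out = attn(x, context)                            # -> (2, 4096, 320)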
@@ -242,7 +295,7 @@ class BasicTransformerBlock(nn.Module):
                  disable_self_attn=False):
         super().__init__()

-        attn_cls = CrossAttention
+        attn_cls = MemoryEfficientCrossAttention
         self.disable_self_attn = disable_self_attn
         self.attn1 = attn_cls(
             query_dim=dim,
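
Note that the replacement is unconditional: every BasicTransformerBlock now routes through the xformers path. A guarded variant, hypothetical and reusing the availability flag sketched above, would be:

# Not in the commit: fall back to the original attention when xformers is absent.
attn_cls = (MemoryEfficientCrossAttention
            if XFORMERS_IS_AVAILABLE else CrossAttention)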
@@ -1401,6 +1454,10 @@ class Img2VidSDUNet(nn.Module):
             module = checkpoint_wrapper(
                 module) if self.use_checkpoint else module
            x = module(x, context)
+        elif isinstance(module, MemoryEfficientCrossAttention):
+            module = checkpoint_wrapper(
+                module) if self.use_checkpoint else module
+            x = module(x, context)
         elif isinstance(module, BasicTransformerBlock):
             module = checkpoint_wrapper(
                 module) if self.use_checkpoint else module
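
Since MemoryEfficientCrossAttention is a separate nn.Module rather than a subclass of CrossAttention, the isinstance dispatch above needs this extra branch; its body mirrors the existing CrossAttention branch, wrapping the module with fairscale's checkpoint_wrapper so activations are recomputed in the backward pass instead of stored. A tiny sketch of the pattern, with illustrative values:

# checkpoint_wrapper returns a drop-in module with the same call signature;
# forward activations are discarded and re-computed during backward.
block = MemoryEfficientCrossAttention(query_dim=320)
use_checkpoint = True  # illustrative config flag
block = checkpoint_wrapper(block) if use_checkpoint else block
# x = block(x, context)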