File size: 3,526 Bytes
cf2f35c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
from diffusers import ModelMixin
from einops import rearrange
from torch import nn


class Motion2bucketModel(ModelMixin):
    """Project per-frame feature windows plus CLIP image tokens into context tokens.

    For every (sample, frame) pair, the frame's feature window
    (``window_size * channels`` values, passed to ``forward`` as
    ``audio_embeds``) is concatenated with the sample's flattened CLIP tokens
    (``clip_token_num * clip_channels`` values). The result goes through a
    three-layer SiLU MLP that emits ``context_tokens`` tokens of width
    ``output_dim``. A SiLU, a linear projection to ``final_output_dim`` and a
    LayerNorm are then applied.

    Args:
        window_size: Number of feature vectors in each frame's window.
        blocks: Stored as ``self.blocks``; not used by the projection (the
            input width has no blocks factor).
        channels: Width of each feature vector in the window.
        clip_channels: Width of each CLIP image token.
        intermediate_dim: Hidden width of the MLP.
        output_dim: Width of each context token produced by ``proj3``.
        context_tokens: Number of context tokens produced per frame.
        clip_token_num: Number of CLIP tokens per sample.
        final_output_dim: Width of each returned token after ``final_proj``.
    """

    def __init__(self, window_size=5, blocks=12, channels=1024, clip_channels=1280, intermediate_dim=512, output_dim=768, context_tokens=32, clip_token_num=1, final_output_dim=5120):
        super().__init__()
        self.window_size = window_size
        self.clip_token_num = clip_token_num
        self.blocks = blocks  # kept for config compatibility; unused in forward
        self.channels = channels
        self.clip_channels = clip_channels
        # One MLP input row = flattened feature window + flattened CLIP tokens.
        self.input_dim = window_size * channels + clip_channels * clip_token_num
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # Three-layer MLP producing all context tokens of a frame at once.
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.act = nn.SiLU()

        self.final_proj = torch.nn.Linear(output_dim, final_output_dim)
        self.final_norm = torch.nn.LayerNorm(final_output_dim)

        # Zero-initialised projection: at init final_proj outputs zeros, and
        # LayerNorm (default affine params) maps an all-zero row to zeros, so
        # the module initially returns all zeros. Presumably intended so the
        # injected conditioning starts as a no-op -- confirm with the trainer.
        nn.init.constant_(self.final_proj.weight, 0)
        if self.final_proj.bias is not None:
            nn.init.constant_(self.final_proj.bias, 0)

    def forward(self, audio_embeds, clip_embeds):
        """Map per-frame feature windows and CLIP tokens to context tokens.

        Parameters:
            audio_embeds (torch.Tensor): Shape
                ``(batch_size, video_length, window_size, channels)``.
            clip_embeds (torch.Tensor): Shape
                ``(clip_batch, clip_token_num, clip_channels)``. ``clip_batch``
                may be ``batch_size`` (one image per sample, shared by all its
                frames), ``1`` (shared by every frame of every sample), or
                ``batch_size * video_length`` (already one row per frame, in
                sample-major order).

        Returns:
            torch.Tensor: Shape
            ``(batch_size, video_length, context_tokens, final_output_dim)``.

        Raises:
            ValueError: If ``clip_batch`` does not evenly divide
                ``batch_size * video_length``.
        """
        video_length = audio_embeds.shape[1]

        # Flatten (sample, frame) into rows and each window into one vector.
        # einops reshapes (copying if needed), so non-contiguous input is fine.
        audio_rows = rearrange(audio_embeds, "bz f w c -> (bz f) (w c)")
        num_rows = audio_rows.shape[0]

        clip_batch = clip_embeds.shape[0]
        if clip_batch == 0 or num_rows % clip_batch != 0:
            raise ValueError(
                f"clip_embeds batch ({clip_batch}) must evenly divide "
                f"batch_size * video_length ({num_rows})"
            )
        # Rows are ordered sample-major ("(bz f)"), so each CLIP row must be
        # repeated consecutively (repeat_interleave), not tiled (repeat);
        # tiling would pair frames with other samples' images when bz > 1.
        clip_rows = clip_embeds.repeat_interleave(num_rows // clip_batch, dim=0)
        clip_rows = rearrange(clip_rows, "b n d -> b (n d)")

        hidden = torch.cat([audio_rows, clip_rows], dim=-1)
        hidden = self.act(self.proj1(hidden))
        hidden = self.act(self.proj2(hidden))

        context_tokens = self.proj3(hidden).reshape(
            num_rows, self.context_tokens, self.output_dim
        )
        # Restore the (sample, frame) axes.
        context_tokens = rearrange(
            context_tokens, "(bz f) m c -> bz f m c", f=video_length
        )

        context_tokens = self.act(context_tokens)
        return self.final_norm(self.final_proj(context_tokens))


if __name__ == '__main__':
    # Smoke test: build the model, push random inputs through it and print
    # the shape of the result pooled over tokens and frames.
    demo_model = Motion2bucketModel(window_size=5)

    # (batch, frames, window, channels) features and (batch, tokens, dim) CLIP tokens.
    frame_windows = torch.randn(1, 81, 5, 1024)
    image_tokens = torch.randn(1, 1, 1280)

    tokens = demo_model(frame_windows, image_tokens)
    pooled = tokens.mean(dim=2)
    pooled = pooled.mean(dim=1)
    print(pooled.size())