import math

import torch


def exist(item):
    """Return True if ``item`` is not None."""
    return item is not None


def freeze(model):
    """Disable gradients for every parameter of ``model`` and return it."""
    for p in model.parameters():
        p.requires_grad = False
    return model


def get_freqs(dim, max_period=10000.):
    """Geometric frequency ladder: freqs[i] = max_period ** (-i / dim), i = 0..dim-1."""
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim
    )
    return freqs
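
# Illustrative use of `get_freqs` (an assumption about intent, not taken from this
# file): such frequencies are the usual ingredient of a sinusoidal timestep embedding.
#   t = torch.tensor([0., 10., 100.])
#   args = t[:, None] * get_freqs(64)[None]                        # (3, 64)
#   emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)    # (3, 128)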


def get_group_sizes(shape, num_groups):
    """Per-axis group size: each axis length divided by its number of groups."""
    return [axis // groups for axis, groups in zip(shape, num_groups)]


def rescale_group_rope(num_groups, scale_factor, rescale_factor):
    """Divide the per-axis group counts and RoPE scale factors by ``rescale_factor``."""
    num_groups = [int(groups / factor) for groups, factor in zip(num_groups, rescale_factor)]
    scale_factor = [scale / factor for scale, factor in zip(scale_factor, rescale_factor)]
    return num_groups, scale_factor


def cat_interleave(visual_query_key_value, text_query_key_value, visual_cu_seqlens, text_cu_seqlens):
    """Interleave visual and text sequences sample by sample along the sequence dim.

    Both inputs are packed along dim=1; the cumulative sequence lengths mark where
    each sample starts and ends. The result is ``[visual_0, text_0, visual_1, text_1, ...]``.
    """
    query_key_value = []
    for local_visual_query_key_value, local_text_query_key_value in zip(
        torch.split(visual_query_key_value, torch.diff(visual_cu_seqlens).tolist(), dim=1),
        torch.split(text_query_key_value, torch.diff(text_cu_seqlens).tolist(), dim=1)
    ):
        query_key_value += [local_visual_query_key_value, local_text_query_key_value]
    query_key_value = torch.cat(query_key_value, dim=1)
    return query_key_value


def split_interleave(out, cu_seqlens, split_len):
    """Split an interleaved output back into its visual and text parts.

    For each sample the last ``split_len`` positions are text tokens, the rest are
    visual tokens. Visual parts are re-packed along dim=1, text parts along dim=0.
    """
    visual_out, text_out = [], []
    for local_out in torch.split(out, torch.diff(cu_seqlens).tolist(), dim=1):
        visual_out.append(local_out[:, :-split_len])
        text_out.append(local_out[0, -split_len:])
    visual_out, text_out = torch.cat(visual_out, dim=1), torch.cat(text_out, dim=0)
    return visual_out, text_out
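
# Assumed usage pattern (not stated in this file): `cat_interleave` packs visual and
# text tokens before attention and `split_interleave` undoes it on the attention
# output, with `split_len` equal to the per-sample text length.
#   qkv = cat_interleave(visual_qkv, text_qkv, visual_cu_seqlens, text_cu_seqlens)
#   # ... attention over qkv ...
#   visual_out, text_out = split_interleave(out, cu_seqlens, split_len=text_len)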


def local_patching(x, shape, group_size, dim=0):
    """Rearrange a (duration, height, width, ...) volume into contiguous groups.

    Starting at axis ``dim``, the three spatial axes are split into blocks of
    ``group_size = (g1, g2, g3)`` and replaced by two axes: (num_groups, g1 * g2 * g3).
    """
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(*x.shape[:dim], duration//g1, g1, height//g2, g2, width//g3, g3, *x.shape[dim+3:])
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim, dim+2, dim+4, dim+1, dim+3, dim+5,
        *range(dim+6, len(x.shape))
    )
    x = x.flatten(dim, dim+2).flatten(dim+1, dim+3)
    return x
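
# Shape example (hypothetical numbers): with shape=(4, 8, 8), group_size=(1, 2, 2)
# and x of shape (4, 8, 8, C), `local_patching(x, shape, group_size, dim=0)` returns
# a tensor of shape (64, 4, C): 64 contiguous 1x2x2 blocks of 4 tokens each.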


def local_merge(x, shape, group_size, dim=0):
    """Inverse of ``local_patching``: restore the (duration, height, width, ...) layout."""
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(*x.shape[:dim], duration//g1, height//g2, width//g3, g1, g2, g3, *x.shape[dim+2:])
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim, dim+3, dim+1, dim+4, dim+2, dim+5,
        *range(dim+6, len(x.shape))
    )
    x = x.flatten(dim, dim+1).flatten(dim+1, dim+2).flatten(dim+2, dim+3)
    return x


def global_patching(x, shape, group_size, dim=0):
    """Like ``local_patching`` but with strided groups: tokens within a group are
    spread across the whole volume rather than forming a contiguous block."""
    latent_group_size = [axis // axis_group_size for axis, axis_group_size in zip(shape, group_size)]
    x = local_patching(x, shape, latent_group_size, dim)
    x = x.transpose(dim, dim+1)
    return x


def global_merge(x, shape, group_size, dim=0):
    """Inverse of ``global_patching``: restore the (duration, height, width, ...) layout."""
    latent_group_size = [axis // axis_group_size for axis, axis_group_size in zip(shape, group_size)]
    x = x.transpose(dim, dim+1)
    x = local_merge(x, shape, latent_group_size, dim)
    return x
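
# Local vs. global grouping on hypothetical numbers: with shape=(4, 8, 8) and
# group_size=(1, 2, 2), both paths yield a (64, 4, C) tensor, but `local_patching`
# puts a contiguous 1x2x2 block in each group while `global_patching` picks 4 tokens
# strided by (4, 4, 4) across the volume (a dilated pattern).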


def to_1dimension(visual_embed, visual_cu_seqlens, visual_shape, num_groups, attention_type):
    """Flatten a 3D visual embedding into grouped 1D sequences for attention.

    ``attention_type`` selects contiguous ('local') or strided ('global') grouping;
    the cumulative sequence lengths are rescaled to match the new token layout.
    """
    group_size = get_group_sizes(visual_shape, num_groups)
    if attention_type == 'local':
        visual_embed = local_patching(visual_embed, visual_shape, group_size, dim=0)
    elif attention_type == 'global':
        visual_embed = global_patching(visual_embed, visual_shape, group_size, dim=0)
    visual_cu_seqlens = visual_cu_seqlens * math.prod(group_size[1:])
    return visual_embed, visual_cu_seqlens


def to_3dimension(visual_embed, visual_shape, num_groups, attention_type):
    """Inverse of ``to_1dimension``: restore the grouped 1D sequence to its 3D layout."""
    group_size = get_group_sizes(visual_shape, num_groups)
    if attention_type == 'local':
        x = local_merge(visual_embed, visual_shape, group_size, dim=0)
    elif attention_type == 'global':
        x = global_merge(visual_embed, visual_shape, group_size, dim=0)
    else:
        raise ValueError(f'unknown attention_type: {attention_type}')
    return x
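

if __name__ == "__main__":
    # Minimal sanity check (illustrative sketch, not part of the library): local and
    # global patching should be exactly undone by their merge counterparts. The shape
    # (4, 8, 8), channel count 16 and group size (1, 2, 2) are arbitrary demo values.
    shape, group_size, channels = (4, 8, 8), (1, 2, 2), 16
    x = torch.randn(*shape, channels)

    patched = local_patching(x, shape, group_size, dim=0)    # (64, 4, 16)
    assert torch.equal(local_merge(patched, shape, group_size, dim=0), x)

    patched = global_patching(x, shape, group_size, dim=0)   # (64, 4, 16)
    assert torch.equal(global_merge(patched, shape, group_size, dim=0), x)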