dokster committed on
Commit 0d89394
1 Parent(s): cbcfd69

Upload 4 files

Files changed (4)
  1. inference.py +145 -0
  2. main.py +115 -0
  3. model.py +222 -0
  4. transformer_clip_gpt-007.pt +3 -0
inference.py ADDED
@@ -0,0 +1,145 @@
+ import cv2
+ import torch
+ import streamlit as st
+
+ from PIL import Image
+ from torch.nn import functional as nnf
+
+ # @st.cache_data
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt='',
+     embed=None,
+     entry_count=1,
+     entry_length=67,
+     top_p=0.98,
+     temperature=1,
+     stop_token='.',
+ ):
+     """Generate text with the GPT-2 head, optionally conditioned on a prefix embedding."""
+
+     # model.eval()
+
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+         for entry_idx in range(entry_count):
+             if tokens is None:  # `not tokens` is ambiguous for a multi-element tensor
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+
+             emb_tokens = model.gpt.transformer.wte(tokens)
+
+             if embed is not None:
+                 generated = torch.cat((embed, emb_tokens), dim=1)
+             else:
+                 generated = emb_tokens
+
+             for i in range(entry_length):
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+
+                 # Nucleus (top-p) mask over the sorted logits.
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+
+                 # Selection is argmax, so decoding is effectively greedy: the top-p mask
+                 # above never removes the highest-probability token.
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+
+             output_text = tokenizer.decode(output_list)
+             output_text = filter_ngrams(output_text)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
+ def filter_ngrams(output_text):
+     """Cut the text at the second ' A:' marker to drop repeated question/answer turns."""
+     a_pos = output_text.find(' A:')
+     sec_a_pos = output_text.find(' A:', a_pos + 1)
+
+     if sec_a_pos == -1:
+         return output_text
+
+     return output_text[:sec_a_pos]
+
+ def image_grid(imgs, rows, cols):
+     """Paste rows * cols PIL images into a single grid image."""
+     assert len(imgs) == rows * cols
+
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+     grid_w, grid_h = grid.size
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+
+     return grid
+
+ @st.cache_data
+ def read_video(path, transform=None, frames_num=9, window=30):
+     """Read up to `frames_num` evenly spaced frames from a video as PIL thumbnails."""
+     frames = []
+
+     cap = cv2.VideoCapture(path)
+
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     N = max(length // frames_num, 1)  # frame step; at least 1 for very short clips
+     current_frame = 1
+
+     for i in range(length):
+         ret, frame = cap.read()  # frames are read sequentially and matched by index below
+
+         if ret and i == current_frame and len(frames) < frames_num:
+             size = 193, 193
+             frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             frame.thumbnail(size, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
+
+             frames.append(frame)
+             current_frame += N
+
+     cap.release()
+
+     return frames
+
+ # @st.cache_data
+ def get_caption(model, tokenizer, prefix, prefix_length, prompt=''):
+     """Project a CLIP embedding to a GPT prefix and generate a caption/answer."""
+     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     prefix = prefix.to(device)
+
+     with torch.no_grad():
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+
+         if prompt:
+             generated_text_prefix = generate2(model, tokenizer, prompt=prompt, embed=prefix_embed)
+         else:
+             generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+     return generated_text_prefix.replace('\n', ' ')
+
+ # @st.cache_data
+ def get_ans(model, tokenizer, clip_emb, prefix_length, prompt):
+     """Answer a question prompt for a CLIP embedding; returns {'answer': ...}."""
+     output = get_caption(model, tokenizer, clip_emb, prefix_length, prompt=prompt)
+     ans = output[len(prompt):].strip()
+
+     return {'answer': ans}
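
For context, a minimal sketch of how these helpers fit together outside the Streamlit app — assuming the checkpoint from this commit has been fetched and using the same CLIP backbone and tokenizer that main.py loads; `video.mp4` and the question text are placeholders (the rugpt3 backbone is Russian, so a real prompt would normally be in Russian):

import clip
import torch
from transformers import GPT2Tokenizer

from model import ClipCaptionModel
from inference import read_video, image_grid, get_caption

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')

model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=50)
model.load_state_dict(torch.load('transformer_clip_gpt-007.pt', map_location='cpu'))
model.to(device).eval()

frames = read_video('video.mp4', frames_num=4)   # 4 evenly spaced frames
grid = image_grid(frames, 2, 2)                  # 2x2 collage fed to CLIP as one image
with torch.no_grad():
    prefix = clip_model.encode_image(preprocess(grid).unsqueeze(0).to(device)).float()

print(get_caption(model, tokenizer, prefix, prefix_length=50,
                  prompt='Question: What is happening in the video? Answer:'))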
main.py ADDED
@@ -0,0 +1,115 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import torch
+ import clip
+ import tempfile
+
+ from tqdm import tqdm
+ from transformers import GPT2Tokenizer
+ from model import *
+ from inference import *
+
+ st.set_page_config(
+     page_title="Video Analysis AI",
+     page_icon="🕶️",
+ )
+
+ @st.cache_resource
+ def load_model():
+     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
+     tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')
+
+     prefix_length = 50
+     model_path = 'transformer_clip_gpt-007.pt'
+     model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=prefix_length)
+
+     model.load_state_dict(torch.load(model_path, map_location='cpu'))
+     model.to(device)
+     model.eval()
+
+     return model, clip_model, preprocess, tokenizer
+
+ def _max_width_():
+     max_width_str = "max-width: 1400px;"
+     st.markdown(
+         f"""
+         <style>
+         .reportview-container .main .block-container{{
+             {max_width_str}
+         }}
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+ _max_width_()
+
+
+ def main():
+     model, clip_model, preprocess, tokenizer = load_model()
+     prefix_length = 50
+
+     st.title("🦾 Video Analysis for Education")
+     st.header("")
+
+     with st.sidebar.expander("ℹ️ - About application", expanded=True):
+         st.write(
+             """
+             - Upload the video
+             - Ask a question about the content of the video
+             - Receive an answer according to your question prompt
+             """
+         )
+
+
+     uploaded_file = st.file_uploader("📌 Upload video: ", ['.mp4'])
+
+     # if play_video:
+     #     video_bytes = uploaded_file.read()
+     #     st.video(video_bytes)
+
+     st.write("---")
+
+     question = st.text_input("❔ Enter question prompt: ", "")
+
+     # st.file_uploader returns None until a file is chosen; stop here so we do not
+     # try to write a temp file from a missing upload.
+     if uploaded_file is None:
+         st.stop()
+
+     tfile = tempfile.NamedTemporaryFile(delete=False)
+     tfile.write(uploaded_file.read())
+
+     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     val_embeddings = []
+     val_captions = []
+     result = ''
+     text = f'Question: {question}? Answer:'
+
+     # read video -> get_ans
+     video = read_video(tfile.name, transform=None, frames_num=4)
+
+     # image_grid expects exactly rows * cols frames (2 x 2 = 4 here)
+     if len(video) > 0:
+         i = image_grid(video, 2, 2)
+         image = preprocess(i).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+
+         val_embeddings.append(prefix)
+         val_captions.append(text)
+
+     answers = []
+
+     for i in tqdm(range(len(val_embeddings))):
+         emb = val_embeddings[i]
+         caption = val_captions[i]
+
+         ans = get_ans(model, tokenizer, emb, prefix_length, caption)
+         answers.append(ans['answer'])
+
+     if answers:
+         result = answers[0].split(' A: ')[0]
+
+     res = st.text_input('✅ Answer to the question', result, disabled=False)
+
+
+ if __name__ == '__main__':
+     main()
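
The app is started with `streamlit run main.py`. For reference, the post-processing of the generated text is plain string slicing — `get_ans` strips the prompt, then `main()` cuts at the first ' A: ' marker; a tiny sketch with a made-up generation:

prompt = 'Question: What is shown in the video? Answer:'
generated = prompt + ' A lecture about machine learning. A: repeated continuation'

ans = generated[len(prompt):].strip()   # what get_ans returns as {'answer': ...}
result = ans.split(' A: ')[0]           # what main() puts in the answer text box
print(result)                           # 'A lecture about machine learning.'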
model.py ADDED
@@ -0,0 +1,222 @@
+ import os
+ import clip
+ import numpy as np
+ import pandas as pd
+ import torch
+ import transformers
+ import torch.nn as nn
+
+ from enum import Enum
+ from torch.nn import functional as nnf
+ from typing import Tuple, Optional, Union
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+ class MappingType(Enum):
+     MLP = 'mlp'
+     Transformer = 'transformer'
+
+ class MlpTransformer(nn.Module):
+     """Two-layer feed-forward block used inside the transformer layers."""
+     def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
+         super().__init__()
+         out_d = out_d if out_d is not None else in_dim
+         self.fc1 = nn.Linear(in_dim, h_dim)
+         self.act = act
+         self.fc2 = nn.Linear(h_dim, out_d)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+
+         return x
+
+ class MLP(nn.Module):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+
+         self.model = nn.Sequential(*layers)
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim_self // num_heads
+         self.scale = head_dim ** -0.5
+         self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+         self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+         self.project = nn.Linear(dim_self, dim_self)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, y=None, mask=None):
+         # Self-attention when y is None, cross-attention to y otherwise.
+         y = y if y is not None else x
+         b, n, c = x.shape
+         _, m, d = y.shape
+
+         queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+         keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
+         keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
+         attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
+
+         if mask is not None:
+             if mask.dim() == 2:
+                 mask = mask.unsqueeze(1)
+             attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+
+         attention = attention.softmax(dim=2)
+
+         out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
+         out = self.project(out)
+
+         return out, attention
+
+
+ class TransformerLayer(nn.Module):
+     def forward_with_attention(self, x, y=None, mask=None):
+         x_, attention = self.attn(self.norm1(x), y, mask)
+         x = x + x_
+         x = x + self.mlp(self.norm2(x))
+
+         return x, attention
+
+     def forward(self, x, y=None, mask=None):
+         x = x + self.attn(self.norm1(x), y, mask)[0]
+         x = x + self.mlp(self.norm2(x))
+
+         return x
+
+     def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
+                  norm_layer: nn.Module = nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim_self)
+         self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
+         self.norm2 = norm_layer(dim_self)
+         self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+
+ class Transformer(nn.Module):
+     def forward_with_attention(self, x, y=None, mask=None):
+         attentions = []
+         for layer in self.layers:
+             x, att = layer.forward_with_attention(x, y, mask)
+             attentions.append(att)
+
+         return x, attentions
+
+     def forward(self, x, y=None, mask=None):
+         for i, layer in enumerate(self.layers):
+             if i % 2 == 0 and self.enc_dec:  # cross-attention layers
+                 x = layer(x, y)
+             elif self.enc_dec:  # self-attention layers
+                 x = layer(x, x, mask)
+             else:
+                 x = layer(x, y, mask)
+         return x
+
+     def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
+                  mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
+         super(Transformer, self).__init__()
+         dim_ref = dim_ref if dim_ref is not None else dim_self
+         self.enc_dec = enc_dec
+
+         if enc_dec:
+             num_layers = num_layers * 2
+
+         layers = []
+
+         for i in range(num_layers):
+             if i % 2 == 0 and enc_dec:  # cross-attention
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             elif enc_dec:  # self-attention
+                 layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             else:
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+
+         self.layers = nn.ModuleList(layers)
+
+
+ class TransformerMapper(nn.Module):
+     """Maps a CLIP image embedding to a sequence of GPT prefix embeddings."""
+     def forward(self, x):
+         x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+         prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
+         prefix = torch.cat((x, prefix), dim=1)
+         out = self.transformer(prefix)[:, self.clip_length:]
+
+         return out
+
+     def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
+         super(TransformerMapper, self).__init__()
+         self.clip_length = clip_length
+         self.transformer = Transformer(dim_embedding, 8, num_layers)
+         self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+         self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
+
+ # Note: duplicate of the MLP class defined above; the later definition is the one in effect.
+ class MLP(nn.Module):
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+
+ class ClipCaptionModel(nn.Module):
+     """GPT-2 language model conditioned on a projected CLIP prefix."""
+     def __init__(self, gpt, prefix_length: int, prefix_size: int = 768):
+         super(ClipCaptionModel, self).__init__()
+
+         self.prefix_length = prefix_length
+         clip_length = prefix_length
+         num_layers = 8
+
+         self.gpt = GPT2LMHeadModel.from_pretrained(gpt)
+         # self.gpt = freeze(self.gpt)
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
+                                               clip_length, num_layers)
+
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: torch.Tensor, prefix: torch.Tensor,
+                 mask: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+
+         if labels is not None:
+             # Pad the labels with dummy tokens for the prefix positions.
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+
+         return out
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+     """Variant that trains only the mapping network and keeps GPT-2 frozen."""
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+
+         return self
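
A minimal shape check for the mapping network, assuming the same rugpt3 backbone that main.py uses (the Hugging Face weights are downloaded on first use); the random tensors below are stand-ins for real CLIP and tokenizer outputs:

import torch
from model import ClipCaptionModel

prefix_length = 50
model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=prefix_length)
model.eval()

prefix = torch.randn(1, 768)              # stand-in for a CLIP ViT-L/14 image embedding
tokens = torch.randint(0, 1000, (1, 12))  # stand-in for 12 tokenized text tokens

with torch.no_grad():
    out = model(tokens, prefix)

# Logits cover the 50 projected prefix positions followed by the 12 text positions.
print(out.logits.shape)                   # torch.Size([1, 62, vocab_size])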
transformer_clip_gpt-007.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1ea7b66ea0f3e84102e9d8d1fbf744ecad61ba6653af4702d0ca668c888bfed
+ size 770490716
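
The checkpoint itself is stored with Git LFS, so a fresh clone contains only the small pointer above until the binary is fetched (for example with `git lfs pull`). A small, hypothetical guard one could place before `torch.load` to fail with a clearer message when only the pointer is present (`is_lfs_pointer` is not part of the commit):

def is_lfs_pointer(path: str) -> bool:
    # Git LFS pointer files are tiny text files starting with this version line.
    with open(path, 'rb') as f:
        return f.read(64).startswith(b'version https://git-lfs.github.com/spec/v1')

if is_lfs_pointer('transformer_clip_gpt-007.pt'):
    raise RuntimeError('Checkpoint is an unfetched Git LFS pointer; run `git lfs pull` first.')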