import os
os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
os.system('pip install gradio --upgrade')
os.system('pip install moviepy')
os.system('pip freeze')

import os
import gradio as gr
from transformers import ViTFeatureExtractor, ViTModel
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt

import cv2
import numpy as np
from tqdm import tqdm
import glob
from PIL import Image
from moviepy.editor import VideoFileClip, ImageSequenceClip

feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=True, padding=True)
model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
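# The checkpoint above is DINO ViT-S/8 (ViT-Small, 8x8 patches); add_pooling_layer=False
# drops the pooling head since only the self-attention maps are needed for visualization.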

def get_attention_maps(pixel_values, attentions, nh, out, img_path):
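  # For one image: threshold each head's CLS->patch attention by cumulative mass,
  # upsample the maps to pixel resolution, average them over heads, and save the
  # result as a heatmap jpg in `out`; returns the saved file path.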
  threshold = 0.6
  w_featmap = pixel_values.shape[-2] // model.config.patch_size
  h_featmap = pixel_values.shape[-1] // model.config.patch_size

  # we keep only a certain percentage of the mass
  val, idx = torch.sort(attentions)
  val /= torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  th_attn = cumval > (1 - threshold)
  idx2 = torch.argsort(idx)
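  # idx2 is the inverse of the sorting permutation, so each head's thresholded mask
  # can be mapped back to the original patch order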
  for head in range(nh):
      th_attn[head] = th_attn[head][idx2[head]]
  th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()

  # interpolate
  th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
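  # th_attn now holds per-head binary masks (top 60% of attention mass) at pixel
  # resolution; they are computed but not included in the heatmap saved below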

  attentions = attentions.reshape(nh, w_featmap, h_featmap)
  attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
  attentions = attentions.detach().numpy()

  # average the per-head attention maps and save the result as a heatmap
  fname = os.path.join(out, os.path.basename(img_path))
  plt.imsave(
      fname=fname,
      arr=attentions.mean(axis=0),
      cmap="inferno",
      format="jpg",
  )
  return fname

def inference(inp: str, out: str):
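  # Run every .jpg frame in `inp` through DINO (with interpolated position encodings
  # so arbitrary frame sizes work), extract the last layer's CLS attention per head,
  # and write one attention heatmap per frame to `out`; returns the output paths in order.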
  print(f"Generating attention images to {out}")

  # I had to process one at a time since colab was crashing...
  fnames = []
  for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
      with open(img_path, "rb") as f:
          img = Image.open(f)
          img = img.convert("RGB")
        
      # normalize channels
      pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values 

      # forward pass
      outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)

      # get attentions of last layer
      attentions = outputs.attentions[-1] 
      nh = attentions.shape[1] # number of heads

      # keep only the [CLS] token's attention over the image patches
      attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

      # sum and save attention maps 
      fnames.append(get_attention_maps(pixel_values, attentions, nh, out, img_path))
  return fnames


def func(video):
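  # Gradio callback: trim the uploaded clip to at most 10 seconds, extract its frames
  # with moviepy, run DINO attention on each frame, and reassemble the heatmaps into
  # a video at the original fps.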
  clip = VideoFileClip(video)
  if clip.duration > 10:
      clip.subclip(0, 10).write_videofile("trim.mp4")
      video = "trim.mp4"

  frames_folder = os.path.join("output", "frames")
  attention_folder = os.path.join("output", "attention")

  os.makedirs(frames_folder, exist_ok=True)
  os.makedirs(attention_folder, exist_ok=True)

  vid = VideoFileClip(video)
  fps = vid.fps

  print(f"Video: {video} ({fps} fps)")
  print(f"Extracting frames to {frames_folder}")

  vid.write_images_sequence(
      os.path.join(frames_folder, "frame-count%03d.jpg"),
  )

  output_frame_fnames = inference(frames_folder, attention_folder)

  new_clip = ImageSequenceClip(output_frame_fnames, fps=fps)
  new_clip.write_videofile("my_new_video.mp4")

  return "my_new_video.mp4"

title = "Interactive demo: DINO"
description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
iface = gr.Interface(fn=func, 
                     inputs=gr.inputs.Video(type=None), 
                     outputs="video",
                     title=title,
                     description=description,
                     article=article)
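# Serve the interface so the demo actually starts when the script runs.
iface.launch()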


title = "Interactive demo: DINO"
description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
iface = gr.Interface(fn=func, 
                     inputs=gr.inputs.Video(type=None), 
                     outputs="video",
                     title=title,
                     description=description,
                     article=article)