SoggyKiwi committed on
Commit
f93fa3d
1 Parent(s): 534e187

add total variation loss + tuning changes

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -11,14 +11,19 @@ model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384'
11
  model.to(device)
12
  model.eval()
13
 
 
 
 
 
 
 
 
 
 
 
14
  def process_image(input_image, learning_rate, iterations, n_targets, seed):
15
  if input_image is None:
16
  return None
17
-
18
- def get_encoder_activations(x):
19
- encoder_output = model.vit(x)
20
- final_activations = encoder_output.last_hidden_state[:,0,:]
21
- return final_activations
22
 
23
  image = input_image.convert('RGB')
24
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
@@ -36,8 +41,11 @@ def process_image(input_image, learning_rate, iterations, n_targets, seed):
36
 
37
  final_activations = get_encoder_activations(pixel_values)
38
  logits = model.classifier(final_activations[0])
39
- target_sum = logits[random_indices].sum()
40
- target_sum.backward()
 
 
 
41
 
42
  with torch.no_grad():
43
  pixel_values.data += learning_rate * pixel_values.grad.data
@@ -52,9 +60,10 @@ iface = gr.Interface(
52
  fn=process_image,
53
  inputs=[
54
  gr.Image(type="pil"),
55
- gr.Number(value=4.0, label="Learning Rate"),
56
- gr.Number(value=4, label="Iterations"),
57
- gr.Number(value=420, label="Seed"),
 
58
  gr.Number(value=50, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
59
  ],
60
  outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
 
11
  model.to(device)
12
  model.eval()
13
 
14
def get_encoder_activations(x):
    """Run ``x`` through ``model.vit`` and return the first-token hidden state.

    Args:
        x: Input forwarded unchanged to ``model.vit`` (elsewhere in this file
            it is called with the feature extractor's pixel values).

    Returns:
        ``last_hidden_state[:, 0, :]`` of the encoder output: index 0 along
        dimension 1 for every batch entry.
        NOTE(review): presumably this is the [CLS] token embedding -- confirm
        against the ViT model documentation.
    """
    return model.vit(x).last_hidden_state[:, 0, :]
18
+
19
def total_variation_loss(img):
    """Return the anisotropic (L1) total variation of an image tensor.

    Sums the absolute differences between neighbouring elements along the
    last two dimensions of ``img``.

    Args:
        img: Tensor with at least two dimensions. The last two dimensions
            are the ones differenced; any leading dimensions are summed over.
            NOTE(review): the caller passes the feature extractor's pixel
            values, presumably (batch, channels, height, width) -- confirm.

    Returns:
        A zero-dimensional tensor: the sum of absolute differences along the
        second-to-last dimension plus the sum along the last dimension.
    """
    # Ellipsis indexing gives the same result as the 4-D-only slices
    # ``[:, :, 1:, :]`` / ``[:, :, :, 1:]`` while also accepting 2-D and
    # 3-D inputs.
    diff_rows = img[..., 1:, :] - img[..., :-1, :]
    diff_cols = img[..., 1:] - img[..., :-1]
    return diff_rows.abs().sum() + diff_cols.abs().sum()
23
+
24
  def process_image(input_image, learning_rate, iterations, n_targets, seed):
25
  if input_image is None:
26
  return None
 
 
 
 
 
27
 
28
  image = input_image.convert('RGB')
29
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
 
41
 
42
  final_activations = get_encoder_activations(pixel_values)
43
  logits = model.classifier(final_activations[0])
44
+
45
+ original_loss = -logits[random_indices].sum()
46
+ tv_loss = total_variation_loss(pixel_values)
47
+ total_loss = original_loss + 0.00625 * tv_loss
48
+ total_loss.backward()
49
 
50
  with torch.no_grad():
51
  pixel_values.data += learning_rate * pixel_values.grad.data
 
60
  fn=process_image,
61
  inputs=[
62
  gr.Image(type="pil"),
63
+ gr.Number(value=10.0, minimum=0, label="Learning Rate"),
64
+ gr.Number(value=0.00625, label="Total Variation Loss"),
65
+ gr.Number(value=1, minimum=1, label="Iterations"),
66
+ gr.Number(value=420, minimum=0, label="Seed"),
67
  gr.Number(value=50, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
68
  ],
69
  outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]