jamino30 commited on
Commit
57ebd4f
1 Parent(s): 962b2f7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -34,6 +34,24 @@ optimal_settings = {
34
  'Watercolor': (10, False),
35
  }
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @spaces.GPU(duration=20)
38
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
39
  yield None
@@ -68,23 +86,7 @@ def inference(content_image, style_image, style_strength, output_quality, progre
68
  optimizer.zero_grad()
69
 
70
  generated_features = model(generated_img)
71
-
72
- content_loss = 0
73
- style_loss = 0
74
-
75
- for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
76
- batch_size, n_feature_maps, height, width = generated_feature.size()
77
-
78
- content_loss += (torch.mean((generated_feature - content_feature) ** 2))
79
-
80
- G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
81
- A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
82
-
83
- E_l = ((G - A) ** 2)
84
- w_l = 1/5
85
- style_loss += torch.mean(w_l * E_l)
86
-
87
- total_loss = alpha * content_loss + beta * style_loss
88
 
89
  total_loss.backward()
90
  optimizer.step()
 
34
  'Watercolor': (10, False),
35
  }
36
 
37
+ def compute_loss(generated_features, content_features, style_features, alpha, beta):
38
+ content_loss = 0
39
+ style_loss = 0
40
+
41
+ for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
42
+ batch_size, n_feature_maps, height, width = generated_feature.size()
43
+
44
+ content_loss += (torch.mean((generated_feature - content_feature) ** 2))
45
+
46
+ G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
47
+ A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
48
+
49
+ E_l = ((G - A) ** 2)
50
+ w_l = 1/5
51
+ style_loss += torch.mean(w_l * E_l)
52
+
53
+ return alpha * content_loss + beta * style_loss
54
+
55
  @spaces.GPU(duration=20)
56
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
57
  yield None
 
86
  optimizer.zero_grad()
87
 
88
  generated_features = model(generated_img)
89
+ total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  total_loss.backward()
92
  optimizer.step()