Spaces: Running on Zero
Upload folder using huggingface_hub
app.py CHANGED
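This commit refactors the style-transfer loop in app.py: the per-iteration content/style loss computation is lifted out of inference into a module-level compute_loss helper, with no change in behavior.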
@@ -34,6 +34,24 @@ optimal_settings = {
     'Watercolor': (10, False),
 }
 
+def compute_loss(generated_features, content_features, style_features, alpha, beta):
+    content_loss = 0
+    style_loss = 0
+
+    for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
+        batch_size, n_feature_maps, height, width = generated_feature.size()
+
+        content_loss += (torch.mean((generated_feature - content_feature) ** 2))
+
+        G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
+        A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
+
+        E_l = ((G - A) ** 2)
+        w_l = 1/5
+        style_loss += torch.mean(w_l * E_l)
+
+    return alpha * content_loss + beta * style_loss
+
 @spaces.GPU(duration=20)
 def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
     yield None
@@ -68,23 +86,7 @@ def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
         optimizer.zero_grad()
 
         generated_features = model(generated_img)
-
-        content_loss = 0
-        style_loss = 0
-
-        for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
-            batch_size, n_feature_maps, height, width = generated_feature.size()
-
-            content_loss += (torch.mean((generated_feature - content_feature) ** 2))
-
-            G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
-            A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
-
-            E_l = ((G - A) ** 2)
-            w_l = 1/5
-            style_loss += torch.mean(w_l * E_l)
-
-        total_loss = alpha * content_loss + beta * style_loss
+        total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
 
         total_loss.backward()
         optimizer.step()
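For context, the extracted helper is the classic Gatys et al. neural style transfer objective: a mean-squared-error content term plus a Gram-matrix style term per feature layer (the fixed w_l = 1/5 weight suggests five style layers averaged equally). Below is a minimal, self-contained sketch of the same loss driven by random stand-in tensors, so it runs without the Space's feature extractor; the layer shapes, optimizer, learning rate, seed, and alpha/beta values are illustrative assumptions, not values from app.py.

import torch

# compute_loss as introduced in this commit (logic identical to the diff above,
# lightly reformatted).
def compute_loss(generated_features, content_features, style_features, alpha, beta):
    content_loss = 0
    style_loss = 0

    for generated_feature, content_feature, style_feature in zip(
            generated_features, content_features, style_features):
        batch_size, n_feature_maps, height, width = generated_feature.size()

        # Content term: MSE between generated and content activations.
        content_loss += torch.mean((generated_feature - content_feature) ** 2)

        # Style term: squared difference of Gram matrices, weighted 1/5 per layer.
        F = generated_feature.view(batch_size * n_feature_maps, height * width)
        S = style_feature.view(batch_size * n_feature_maps, height * width)
        G = torch.mm(F, F.t())
        A = torch.mm(S, S.t())
        style_loss += torch.mean((1 / 5) * (G - A) ** 2)

    return alpha * content_loss + beta * style_loss

# Hypothetical driver: random tensors stand in for VGG feature maps, and the
# "generated" features are optimized directly so no model is needed.
torch.manual_seed(0)
shapes = [(1, 64, 32, 32), (1, 128, 16, 16)]               # assumed layer shapes
content_features = [torch.randn(s) for s in shapes]
style_features = [torch.randn(s) for s in shapes]
generated_features = [torch.randn(s, requires_grad=True) for s in shapes]

optimizer = torch.optim.Adam(generated_features, lr=0.01)  # assumed optimizer/lr
for step in range(5):
    optimizer.zero_grad()
    total_loss = compute_loss(generated_features, content_features,
                              style_features, alpha=1.0, beta=0.01)  # assumed weights
    total_loss.backward()
    optimizer.step()
print(f"final loss: {total_loss.item():.4f}")

Extracting the loss into a pure function like this also makes it straightforward to unit-test against fixed tensors, independent of the GPU-decorated inference path.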