Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitignore +3 -1
- inference.py +10 -5
.gitignore
CHANGED
@@ -167,4 +167,6 @@ cython_debug/
|
|
167 |
# Gradio
|
168 |
gradio_cached_examples
|
169 |
flagged
|
170 |
-
|
|
|
|
|
|
167 |
# Gradio
|
168 |
gradio_cached_examples
|
169 |
flagged
|
170 |
+
*.jpg
|
171 |
+
*.jpeg
|
172 |
+
*.png
|
inference.py
CHANGED
@@ -1,10 +1,15 @@
|
|
1 |
-
from tqdm import tqdm
|
2 |
-
|
3 |
import torch
|
4 |
import torch.optim as optim
|
5 |
import torch.nn.functional as F
|
|
|
6 |
from torchvision.transforms.functional import gaussian_blur
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def _gram_matrix(feature):
|
10 |
batch_size, n_feature_maps, height, width = feature.size()
|
@@ -57,9 +62,9 @@ def inference(
|
|
57 |
if apply_to_background:
|
58 |
segmentation_output = segmentation_model(content_image)['out']
|
59 |
segmentation_mask = segmentation_output.argmax(dim=1)
|
60 |
-
|
61 |
background_mask = (segmentation_mask == 0).float()
|
62 |
-
foreground_mask =
|
|
|
63 |
|
64 |
background_pixel_count = background_mask.sum().item()
|
65 |
total_pixel_count = segmentation_mask.numel()
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.optim as optim
|
3 |
import torch.nn.functional as F
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
from torchvision.transforms.functional import gaussian_blur
|
6 |
+
|
7 |
+
def save_mask(mask, title='mask'):
|
8 |
+
plt.imshow(mask.cpu().numpy()[0], cmap='gray')
|
9 |
+
plt.title(title)
|
10 |
+
plt.axis('off')
|
11 |
+
plt.savefig(f'{title}.png', bbox_inches='tight')
|
12 |
+
plt.close()
|
13 |
|
14 |
def _gram_matrix(feature):
|
15 |
batch_size, n_feature_maps, height, width = feature.size()
|
|
|
62 |
if apply_to_background:
|
63 |
segmentation_output = segmentation_model(content_image)['out']
|
64 |
segmentation_mask = segmentation_output.argmax(dim=1)
|
|
|
65 |
background_mask = (segmentation_mask == 0).float()
|
66 |
+
foreground_mask = 1 - background_mask
|
67 |
+
save_mask(background_mask, title='background-mask')
|
68 |
|
69 |
background_pixel_count = background_mask.sum().item()
|
70 |
total_pixel_count = segmentation_mask.numel()
|