jamino30 commited on
Commit
8d1740c
·
verified ·
1 Parent(s): 35add36

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. inference.py +10 -5
.gitignore CHANGED
@@ -167,4 +167,6 @@ cython_debug/
167
  # Gradio
168
  gradio_cached_examples
169
  flagged
170
- generated.jpg
 
 
 
167
  # Gradio
168
  gradio_cached_examples
169
  flagged
170
+ *.jpg
171
+ *.jpeg
172
+ *.png
inference.py CHANGED
@@ -1,10 +1,15 @@
1
- from tqdm import tqdm
2
-
3
  import torch
4
  import torch.optim as optim
5
  import torch.nn.functional as F
 
6
  from torchvision.transforms.functional import gaussian_blur
7
- from torchvision import models
 
 
 
 
 
 
8
 
9
  def _gram_matrix(feature):
10
  batch_size, n_feature_maps, height, width = feature.size()
@@ -57,9 +62,9 @@ def inference(
57
  if apply_to_background:
58
  segmentation_output = segmentation_model(content_image)['out']
59
  segmentation_mask = segmentation_output.argmax(dim=1)
60
-
61
  background_mask = (segmentation_mask == 0).float()
62
- foreground_mask = (segmentation_mask != 0).float()
 
63
 
64
  background_pixel_count = background_mask.sum().item()
65
  total_pixel_count = segmentation_mask.numel()
 
 
 
1
  import torch
2
  import torch.optim as optim
3
  import torch.nn.functional as F
4
+ import matplotlib.pyplot as plt
5
  from torchvision.transforms.functional import gaussian_blur
6
+
7
+ def save_mask(mask, title='mask'):
8
+ plt.imshow(mask.cpu().numpy()[0], cmap='gray')
9
+ plt.title(title)
10
+ plt.axis('off')
11
+ plt.savefig(f'{title}.png', bbox_inches='tight')
12
+ plt.close()
13
 
14
  def _gram_matrix(feature):
15
  batch_size, n_feature_maps, height, width = feature.size()
 
62
  if apply_to_background:
63
  segmentation_output = segmentation_model(content_image)['out']
64
  segmentation_mask = segmentation_output.argmax(dim=1)
 
65
  background_mask = (segmentation_mask == 0).float()
66
+ foreground_mask = 1 - background_mask
67
+ save_mask(background_mask, title='background-mask')
68
 
69
  background_pixel_count = background_mask.sum().item()
70
  total_pixel_count = segmentation_mask.numel()