jamino30 committed
Commit 28ac920
Parent: ecf0440

Upload folder using huggingface_hub

Files changed (3):
  1. app.py             +3 -1
  2. inference.py       +31 -54
  3. u2net/inference.py +11 -7
app.py CHANGED
@@ -30,7 +30,8 @@ def load_model_without_module(model, model_path):
         name = k[7:] if k.startswith('module.') else k
         new_state_dict[name] = v
     model.load_state_dict(new_state_dict)
-
+
+# load models
 model = VGG_19().to(device).eval()
 for param in model.parameters():
     param.requires_grad = False
@@ -43,6 +44,7 @@ style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images
 lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
 img_size = 512
 
+# store style(s) features
 cached_style_features = {}
 for style_name, style_img_path in style_options.items():
     style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device)
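
Note for reviewers: the `# store style(s) features` block precomputes VGG features for every style image once at startup, so per-request inference only runs forward passes on the user's content image. A minimal sketch of how the cache is presumably filled (the loop body past this hunk is not shown in the diff; `model(style_img)` returning the list of intermediate feature maps is an assumption):

with torch.no_grad():
    for style_name, style_img_path in style_options.items():
        style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device)
        # assumed: VGG_19 returns the per-layer feature maps consumed by compute_loss
        cached_style_features[style_name] = model(style_img)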
inference.py CHANGED
@@ -1,37 +1,24 @@
-import os
-from tqdm import tqdm
-
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
 from torchvision.transforms.functional import gaussian_blur
+from tqdm import tqdm
 
-def _gram_matrix(feature):
-    batch_size, n_feature_maps, height, width = feature.size()
-    new_feature = feature.view(batch_size * n_feature_maps, height * width)
-    return torch.mm(new_feature, new_feature.t())
+def gram_matrix(feature):
+    b, c, h, w = feature.size()
+    feature = feature.view(b * c, h * w)
+    return feature @ feature.t()
 
-def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
-    content_loss = 0
-    style_loss = 0
-    w_l = 1 / len(generated_features)
-
-    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
-        content_loss += F.mse_loss(gf, cf)
-
-        if resized_bg_masks:
-            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
-            masked_gf = gf * blurred_bg_mask
-            masked_sf = sf * blurred_bg_mask
-            G = _gram_matrix(masked_gf)
-            A = _gram_matrix(masked_sf)
-        else:
-            G = _gram_matrix(gf)
-            A = _gram_matrix(sf)
-        style_loss += w_l * F.mse_loss(G, A)
-
-    total_loss = alpha * content_loss + beta * style_loss
-    return content_loss, style_loss, total_loss
+def compute_loss(generated, content, style, bg_masks, alpha, beta):
+    content_loss = sum(F.mse_loss(gf, cf) for gf, cf in zip(generated, content))
+    style_loss = sum(
+        F.mse_loss(
+            gram_matrix(gf * bg) if bg is not None else gram_matrix(gf),
+            gram_matrix(sf * bg) if bg is not None else gram_matrix(sf),
+        ) / len(generated)
+        for gf, sf, bg in zip(generated, style, bg_masks or [None] * len(generated))
+    )
+    return alpha * content_loss, beta * style_loss, alpha * content_loss + beta * style_loss
 
 def inference(
     *,
@@ -41,7 +28,7 @@ def inference(
     content_image_norm,
     style_features,
     apply_to_background,
-    lr,
+    lr=5e-2,
     iterations=101,
     optim_caller=optim.AdamW,
     alpha=1,
@@ -49,43 +36,33 @@ def inference(
 ):
     generated_image = content_image.clone().requires_grad_(True)
     optimizer = optim_caller([generated_image], lr=lr)
-    min_losses = [float('inf')] * iterations
 
     with torch.no_grad():
         content_features = model(content_image)
-
-    resized_bg_masks = []
+    bg_masks = None
+
     if apply_to_background:
-        segmentation_output = sod_model(content_image_norm)[0]
-        segmentation_output = torch.sigmoid(segmentation_output)
-        segmentation_mask = (segmentation_output > 0.7).float()
-        background_mask = (segmentation_mask == 0).float()
-        foreground_mask = 1 - background_mask
-
-        for cf in content_features:
-            _, _, h_i, w_i = cf.shape
-            bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
-            resized_bg_masks.append(bg_mask)
+        seg_output = torch.sigmoid(sod_model(content_image_norm)[0])
+        bg_mask = (seg_output <= 0.7).float()
+        bg_masks = [
+            F.interpolate(bg_mask.unsqueeze(1), size=cf.shape[2:], mode='bilinear', align_corners=False)
+            for cf in content_features
+        ]
 
-    def closure(iter):
+    def closure():
         optimizer.zero_grad()
         generated_features = model(generated_image)
-        content_loss, style_loss, total_loss = _compute_loss(
-            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
+        content_loss, style_loss, total_loss = compute_loss(
+            generated_features, content_features, style_features, bg_masks, alpha, beta
        )
         total_loss.backward()
-
-        # log loss
-        min_losses[iter] = min(min_losses[iter], total_loss.item())
-
         return total_loss
 
-    for iter in tqdm(range(iterations)):
-        optimizer.step(lambda: closure(iter))
-
+    for _ in tqdm(range(iterations)):
+        optimizer.step(closure)
     if apply_to_background:
         with torch.no_grad():
-            foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
-            generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
+            fg_mask = F.interpolate(1 - bg_masks[0], size=generated_image.shape[2:], mode='nearest')
+            generated_image.data.mul_(1 - fg_mask).add_(content_image.data * fg_mask)
 
     return generated_image
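
For context, a hypothetical call site for the refactored `inference`. The keyword names `model`, `sod_model`, and `content_image` appear in the function body, but the full signature sits outside these hunks, so treat them as assumptions; `vgg`, `u2net`, `content`, and `content_norm` are placeholder variables:

output = inference(
    model=vgg,                        # feature extractor returning a list of feature maps
    sod_model=u2net,                  # salient-object detector, used when apply_to_background=True
    content_image=content,            # (1, 3, H, W) tensor to stylize
    content_image_norm=content_norm,  # normalized copy fed to the SOD model
    style_features=cached_style_features[style_name],
    apply_to_background=True,         # lr now defaults to 5e-2; pick from the lrs grid to override
)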
u2net/inference.py CHANGED
@@ -9,19 +9,22 @@ from matplotlib.gridspec import GridSpec
 
 from model import U2Net
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available(): device = 'cuda'
+elif torch.backends.mps.is_available(): device = 'mps'
+else: device = 'cpu'
+device = torch.device(device)
 
 def preprocess_image(image_path):
     img = Image.open(image_path).convert('RGB')
     preprocess = transforms.Compose([
-        transforms.Resize((512, 512)),
+        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
     img = preprocess(img).unsqueeze(0).to(device)
     return img
 
-def run_inference(model, image_path, threshold=0.5):
+def run_inference(model, image_path, threshold=None):
     input_img = preprocess_image(image_path)
     with torch.no_grad():
         d1, *_ = model(input_img)
@@ -47,15 +50,16 @@ def overlay_segmentation(original_image, binary_mask, alpha=0.5):
 
 if __name__ == '__main__':
     # ---
-    model_path = 'results/u2net-duts-msra.safetensors'
-    image_path = 'images/ladies.jpg'
+    model_path = '../testing/u2net-duts-msra.safetensors'
+    filename = input('Filename: ')
+    image_path = f'../content_images/{filename}'
     # ---
     model = U2Net().to(device)
     model = nn.DataParallel(model)
     model.load_state_dict(load_file(model_path, device=device.type))
     model.eval()
 
-    mask = run_inference(model, image_path, threshold=None)
+    mask = run_inference(model, image_path)
     mask_with_threshold = run_inference(model, image_path, threshold=0.7)
 
     fig = plt.figure(figsize=(10, 10))
@@ -74,4 +78,4 @@ if __name__ == '__main__':
     ax.axis('off')
 
     plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
-    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
+    plt.savefig('../testing/inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
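
The new device pick adds Apple-silicon (MPS) support, but note that `torch.backends.mps` only exists in PyTorch 1.12 and later. A defensive sketch that degrades gracefully on older builds:

def pick_device() -> torch.device:
    # prefer CUDA, then MPS (PyTorch >= 1.12), then CPU
    if torch.cuda.is_available():
        return torch.device('cuda')
    mps = getattr(torch.backends, 'mps', None)  # attribute absent on older PyTorch
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')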