Sophie98 committed
Commit baaaa83
Parent: ab7b996

find next error

Files changed (2):
1. app.py (+4, -0)
2. styleTransfer.py (+1, -6)
app.py CHANGED
```diff
@@ -96,11 +96,15 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
     resized_style = resize_style(style_img)
     resized_style.save('resized_style.jpg')
     # generate mask for image
+    print('generating mask...')
     mask = get_mask(resized_img)
     mask.save('mask.jpg')
+    # Created a styled sofa
+    print('Styling sofa...')
     styled_sofa = create_styledSofa(resized_img,resized_style)
     styled_sofa.save('styled_sofa.jpg')
     # postprocess the final image
+    print('Replacing sofa...')
     new_sofa = replace_sofa(resized_img,mask,styled_sofa)
     new_sofa = new_sofa.crop(box)
     return new_sofa
```
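The app.py change only threads progress `print`s between the pipeline stages, so the Space's container log shows which stage was running when a request stalls; the commit message ("find next error") suggests the prints are there to bisect a runtime failure. A minimal sketch of that pattern follows. The function name and the `time.sleep` calls are hypothetical stand-ins for the real `get_mask` / `create_styledSofa` / `replace_sofa` helpers, and `flush=True` is an extra here (not in the commit) that guarantees each marker reaches the log even if the next stage crashes hard:

```python
import time

def staged_pipeline() -> None:
    # Announce each stage before running it: the last marker in the log
    # identifies the stage that crashed or hung.
    print('generating mask...', flush=True)   # stand-in for get_mask()
    time.sleep(0.1)
    print('Styling sofa...', flush=True)      # stand-in for create_styledSofa()
    time.sleep(0.1)
    print('Replacing sofa...', flush=True)    # stand-in for replace_sofa()
    time.sleep(0.1)

staged_pipeline()
```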
styleTransfer.py CHANGED
```diff
@@ -20,7 +20,6 @@ embedding_path = 'embedding_iter_160000.pth'
 
 def style_transform(h,w):
     k = (h,w)
-    size = int(np.max(k))
     transform_list = []
     transform_list.append(transforms.CenterCrop((h,w)))
     transform_list.append(transforms.ToTensor())
@@ -37,8 +36,6 @@ def content_transform():
 content_size=640
 style_size=640
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 vgg = StyTR.vgg
 vgg.load_state_dict(torch.load(vgg_path))
 vgg = nn.Sequential(*list(vgg.children())[:44])
@@ -70,14 +67,12 @@ content_tf = content_transform()
 style_tf = style_transform(style_size,style_size)
 
 def StyleTransformer(content_img: Image, style_img: Image):
-
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     network.to(device)
-
     content = content_tf(content_img.convert("RGB"))
     style = style_tf(style_img.convert("RGB"))
     style = style.to(device).unsqueeze(0)
     content = content.to(device).unsqueeze(0)
-
     with torch.no_grad():
         output= network(content,style)
         output = output[0].cpu().squeeze()
```
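The styleTransfer.py change drops the unused `size = int(np.max(k))` computation in `style_transform` and moves device selection from module scope into `StyleTransformer` itself, so the device is resolved when inference is actually requested rather than at import time. A self-contained sketch of the same pattern, assuming nothing from the repo (the `run_inference` name and the `nn.Identity` stand-in for the StyTR network are illustrative):

```python
import torch
import torch.nn as nn

def run_inference(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Resolve the device at call time, not at import time, so importing
    # this module never touches CUDA on a CPU-only host.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    with torch.no_grad():
        # Move the input alongside the model, then bring the result back
        # to the CPU for downstream PIL/numpy postprocessing.
        return model(x.to(device)).cpu()

# Usage: a trivial model in place of the real style-transfer network.
print(run_inference(nn.Identity(), torch.ones(2, 2)))
```

One side effect of call-time selection is that `model.to(device)` runs on every call; that is a no-op once the weights are already on the target device, so the main benefit (importable without a GPU) comes essentially for free.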