Sophie98 commited on
Commit
e4fb230
·
1 Parent(s): 79c6687

restructured code

Browse files
Files changed (6) hide show
  1. .gitignore +8 -0
  2. StyTR.py +1 -1
  3. app.py +72 -2
  4. segmentation.py +6 -8
  5. styleTransfer.py +108 -71
  6. test.py +0 -175
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/box_ops.cpython-37.pyc
2
+ __pycache__/function.cpython-37.pyc
3
+ __pycache__/misc.cpython-37.pyc
4
+ __pycache__/segmentation.cpython-37.pyc
5
+ __pycache__/styleTransfer.cpython-37.pyc
6
+ __pycache__/StyTR.cpython-37.pyc
7
+ __pycache__/transformer.cpython-37.pyc
8
+ __pycache__/ViT_helper.cpython-37.pyc
StyTR.py CHANGED
@@ -137,7 +137,7 @@ class MLP(nn.Module):
137
  class StyTrans(nn.Module):
138
  """ This is the style transform transformer module """
139
 
140
- def __init__(self,encoder,decoder,PatchEmbed, transformer,args):
141
 
142
  super().__init__()
143
  enc_layers = list(encoder.children())
 
137
  class StyTrans(nn.Module):
138
  """ This is the style transform transformer module """
139
 
140
+ def __init__(self,encoder,decoder,PatchEmbed, transformer):
141
 
142
  super().__init__()
143
  enc_layers = list(encoder.children())
app.py CHANGED
@@ -1,9 +1,78 @@
1
  import numpy as np
2
  import gradio as gr
3
  from segmentation import get_mask,replace_sofa
4
- from styleTransfer import resize_sofa,resize_style,create_styledSofa
5
  from PIL import Image
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
8
  """
9
  Styles (all) the sofas in the image to the given style.
@@ -17,6 +86,7 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
17
  """
18
 
19
  # preprocess input images to be (640,640) squares to fit requirements of the segmentation model
 
20
  resized_img,box = resize_sofa(input_img)
21
  resized_style = resize_style(style_img)
22
  # generate mask for image
@@ -24,7 +94,7 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
24
  styled_sofa = create_styledSofa(resized_img,resized_style)
25
  # postprocess the final image
26
  new_sofa = replace_sofa(resized_img,mask,styled_sofa)
27
- new_sofa = Image.fromarray(new_sofa).crop(box)
28
  return new_sofa
29
 
30
  image = gr.inputs.Image()
 
1
  import numpy as np
2
  import gradio as gr
3
  from segmentation import get_mask,replace_sofa
4
+ from styleTransfer import create_styledSofa
5
  from PIL import Image
6
 
7
+ def resize_sofa(img):
8
+ """
9
+ This function adds padding to make the orignal image square and 640by640.
10
+ It also returns the orignal ratio of the image, such that it can be reverted later.
11
+ Parameters:
12
+ img = original image
13
+ Return:
14
+ im1 = squared image
15
+ box = parameters to later crop the image to it original ratio
16
+ """
17
+ width, height = img.size
18
+ idx = np.argmin([width,height])
19
+ newsize = (640, 640) # parameters from test script
20
+
21
+ if idx==0:
22
+ img1 = Image.new(img.mode, (height, height), (255, 255, 255))
23
+ img1.paste(img, ((height-width)//2, 0))
24
+ box = ( newsize[0]*(1-width/height)//2,
25
+ 0,
26
+ newsize[0]-newsize[0]*(1-width/height)//2,
27
+ newsize[1])
28
+ else:
29
+ img1 = Image.new(img.mode, (width, width), (255, 255, 255))
30
+ img1.paste(img, (0, (width-height)//2))
31
+ box = (0,
32
+ newsize[1]*(1-height/width)//2,
33
+ newsize[0],
34
+ newsize[1]-newsize[1]*(1-height/width)//2)
35
+ im1 = img1.resize(newsize)
36
+ return im1,box
37
+
38
+ def resize_style(img):
39
+ """
40
+ This function generates a zoomed out version of the style image and resizes it to a 640by640 square.
41
+ Parameters:
42
+ img = image containing the style/pattern
43
+ Return:
44
+ dst = a zoomed-out and resized version of the pattern
45
+ """
46
+ width, height = img.size
47
+ idx = np.argmin([width,height])
48
+
49
+ # Makes the image square by cropping
50
+ if idx==0:
51
+ top= (height-width)//2
52
+ bottom= height-(height-width)//2
53
+ left = 0
54
+ right= width
55
+ else:
56
+ left = (width-height)//2
57
+ right = width - (width-height)//2
58
+ top = 0
59
+ bottom = height
60
+ newsize = (640, 640) # parameters from test script
61
+ im1 = img.crop((left, top, right, bottom))
62
+
63
+ # Constructs a zoomed-out version
64
+ copies = 8
65
+ resize = (newsize[0]//copies,newsize[1]//copies)
66
+ dst = Image.new('RGB', (resize[0]*copies,resize[1]*copies))
67
+ im2 = im1.resize((resize))
68
+ for row in range(copies):
69
+ im2 = im2.transpose(Image.FLIP_LEFT_RIGHT)
70
+ for column in range(copies):
71
+ im2 = im2.transpose(Image.FLIP_TOP_BOTTOM)
72
+ dst.paste(im2, (resize[0]*row, resize[1]*column))
73
+ dst = dst.resize((newsize))
74
+ return dst
75
+
76
  def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
77
  """
78
  Styles (all) the sofas in the image to the given style.
 
86
  """
87
 
88
  # preprocess input images to be (640,640) squares to fit requirements of the segmentation model
89
+ input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
90
  resized_img,box = resize_sofa(input_img)
91
  resized_style = resize_style(style_img)
92
  # generate mask for image
 
94
  styled_sofa = create_styledSofa(resized_img,resized_style)
95
  # postprocess the final image
96
  new_sofa = replace_sofa(resized_img,mask,styled_sofa)
97
+ new_sofa = new_sofa.crop(box)
98
  return new_sofa
99
 
100
  image = gr.inputs.Image()
segmentation.py CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
10
- def get_mask(image):
11
  """
12
  This function generates a mask of the image that highlights all the sofas in the image.
13
  This uses a pre-trained Unet model with a resnet50 backbone.
@@ -50,13 +50,11 @@ def get_mask(image):
50
  test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
51
  test_img = np.expand_dims(test_img, axis=0)
52
 
53
- prediction = model.predict(preprocess_input(test_img)).round()
54
- print("generated mask")
55
  mask = Image.fromarray(prediction[...,0].squeeze()*255).convert("L")
56
- #mask.save("masks/sofa.jpg")
57
- return np.array(mask)
58
 
59
- def replace_sofa(image,mask,styled_sofa):
60
  """
61
  This function replaces the original sofa in the image by the new styled sofa according
62
  to the mask.
@@ -68,7 +66,7 @@ def replace_sofa(image,mask,styled_sofa):
68
  Return:
69
  new_image = Image containing the styled sofa
70
  """
71
- image = np.array(image)
72
  #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
73
  styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)
74
 
@@ -77,7 +75,7 @@ def replace_sofa(image,mask,styled_sofa):
77
  image_bg = cv2.bitwise_and(image,image,mask = mask_inv)
78
  sofa_fg = cv2.bitwise_and(styled_sofa,styled_sofa,mask = mask)
79
  new_image = cv2.add(image_bg,sofa_fg)
80
- return new_image
81
 
82
  # image = cv2.imread('input/sofa.jpg')
83
  # mask = cv2.imread('masks/sofa.jpg')
 
7
  from PIL import Image
8
  import segmentation_models as sm
9
 
10
+ def get_mask(image:Image) -> Image:
11
  """
12
  This function generates a mask of the image that highlights all the sofas in the image.
13
  This uses a pre-trained Unet model with a resnet50 backbone.
 
50
  test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
51
  test_img = np.expand_dims(test_img, axis=0)
52
 
53
+ prediction = model.predict(preprocess_input(np.array(test_img))).round()
 
54
  mask = Image.fromarray(prediction[...,0].squeeze()*255).convert("L")
55
+ return mask
 
56
 
57
+ def replace_sofa(image:Image, mask:Image, styled_sofa:Image) -> Image:
58
  """
59
  This function replaces the original sofa in the image by the new styled sofa according
60
  to the mask.
 
66
  Return:
67
  new_image = Image containing the styled sofa
68
  """
69
+ image,mask,styled_sofa = np.array(image),np.array(mask),np.array(styled_sofa)
70
  #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
71
  styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)
72
 
 
75
  image_bg = cv2.bitwise_and(image,image,mask = mask_inv)
76
  sofa_fg = cv2.bitwise_and(styled_sofa,styled_sofa,mask = mask)
77
  new_image = cv2.add(image_bg,sofa_fg)
78
+ return Image.fromarray(new_image)
79
 
80
  # image = cv2.imread('input/sofa.jpg')
81
  # mask = cv2.imread('masks/sofa.jpg')
styleTransfer.py CHANGED
@@ -1,77 +1,114 @@
1
  from PIL import Image
2
  import numpy as np
3
- import os
4
- import cv2
5
-
6
- def resize_sofa(img):
7
- img = Image.fromarray(img)
8
- width, height = img.size
9
- idx = np.argmin([width,height])
10
- newsize = (640, 640) # parameters from test script
11
-
12
- if idx==0:
13
- img1 = Image.new(img.mode, (height, height), (255, 255, 255))
14
- img1.paste(img, ((height-width)//2, 0))
15
- box = ( newsize[0]*(1-width/height)//2,
16
- 0,
17
- newsize[0]-newsize[0]*(1-width/height)//2,
18
- newsize[1])
19
- else:
20
- img1 = Image.new(img.mode, (width, width), (255, 255, 255))
21
- img1.paste(img, (0, (width-height)//2))
22
- box = (0,
23
- newsize[1]*(1-height/width)//2,
24
- newsize[0],
25
- newsize[1]-newsize[1]*(1-height/width)//2)
26
- im1 = img1.resize(newsize)
27
- return im1,box
28
-
29
- def resize_style(img):
30
- #img = Image.open(path)#"../style5.jpg")
31
- img = Image.fromarray(img)
32
- width, height = img.size
33
- idx = np.argmin([width,height])
34
- #print(width,height)
35
-
36
- if idx==0:
37
- top= (height-width)//2
38
- bottom= height-(height-width)//2
39
- left = 0
40
- right= width
41
- else:
42
- left = (width-height)//2
43
- right = width - (width-height)//2
44
- top = 0
45
- bottom = height
46
-
47
- newsize = (640, 640) # parameters from test script
48
- im1 = img.crop((left, top, right, bottom))
49
-
50
- copies = 8
51
- resize = (newsize[0]//copies,newsize[1]//copies)
52
- dst = Image.new('RGB', (resize[0]*copies,resize[1]*copies))
53
- im2 = im1.resize((resize))
54
- for row in range(copies):
55
- im2 = im2.transpose(Image.FLIP_LEFT_RIGHT)
56
- for column in range(copies):
57
- im2 = im2.transpose(Image.FLIP_TOP_BOTTOM)
58
- dst.paste(im2, (resize[0]*row, resize[1]*column))
59
- dst = dst.resize((newsize))
60
- return dst
61
-
62
- def create_styledSofa(sofa,style):
63
- path_sofa,path_style = 'sofa.jpg','style.jpg'
64
- sofa.save(path_sofa)
65
- style.save(path_style)
66
- os.system("python3 test.py --content "+path_sofa+" \
67
- --style "+path_style+" \
68
- --output . \
69
- --vgg vgg_normalised.pth \
70
- --decoder_path decoder_iter_160000.pth \
71
- --Trans_path transformer_iter_160000.pth \
72
- --embedding_path embedding_iter_160000.pth")
73
- styled_sofa = cv2.imread('sofa_stylized_style.jpg')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
 
 
 
 
 
 
 
75
  return styled_sofa
76
 
77
  # image = Image.open('sofa_office.jpg')
 
1
  from PIL import Image
2
  import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ import transformer as transformer
8
+ import StyTR as StyTR
9
+ import numpy as np
10
+ from collections import OrderedDict
11
+
12
+ def test_transform(size, crop):
13
+ transform_list = []
14
+
15
+ if size != 0:
16
+ transform_list.append(transforms.Resize(size))
17
+ if crop:
18
+ transform_list.append(transforms.CenterCrop(size))
19
+ transform_list.append(transforms.ToTensor())
20
+ transform = transforms.Compose(transform_list)
21
+ return transform
22
+ def style_transform(h,w):
23
+ k = (h,w)
24
+ size = int(np.max(k))
25
+ transform_list = []
26
+ transform_list.append(transforms.CenterCrop((h,w)))
27
+ transform_list.append(transforms.ToTensor())
28
+ transform = transforms.Compose(transform_list)
29
+ return transform
30
+
31
+ def content_transform():
32
+ transform_list = []
33
+ transform_list.append(transforms.ToTensor())
34
+ transform = transforms.Compose(transform_list)
35
+ return transform
36
+
37
+ def Transformer(content_img: Image, style_img: Image,
38
+ vgg_path:str = 'vgg_normalised.pth', decoder_path:str = 'decoder_iter_160000.pth',
39
+ Trans_path:str = 'transformer_iter_160000.pth', embedding_path:str = 'embedding_iter_160000.pth'):
40
+ # Advanced options
41
+ content_size=640
42
+ style_size=640
43
+ crop='store_true'
44
+
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+ vgg = StyTR.vgg
48
+ vgg.load_state_dict(torch.load(vgg_path))
49
+ vgg = nn.Sequential(*list(vgg.children())[:44])
50
+
51
+ decoder = StyTR.decoder
52
+ Trans = transformer.Transformer()
53
+ embedding = StyTR.PatchEmbed()
54
+
55
+ decoder.eval()
56
+ Trans.eval()
57
+ vgg.eval()
58
+
59
+ new_state_dict = OrderedDict()
60
+ state_dict = torch.load(decoder_path)
61
+ for k, v in state_dict.items():
62
+ #namekey = k[7:] # remove `module.`
63
+ namekey = k
64
+ new_state_dict[namekey] = v
65
+ decoder.load_state_dict(new_state_dict)
66
+
67
+ new_state_dict = OrderedDict()
68
+ state_dict = torch.load(Trans_path)
69
+ for k, v in state_dict.items():
70
+ #namekey = k[7:] # remove `module.`
71
+ namekey = k
72
+ new_state_dict[namekey] = v
73
+ Trans.load_state_dict(new_state_dict)
74
+
75
+ new_state_dict = OrderedDict()
76
+ state_dict = torch.load(embedding_path)
77
+ for k, v in state_dict.items():
78
+ #namekey = k[7:] # remove `module.`
79
+ namekey = k
80
+ new_state_dict[namekey] = v
81
+ embedding.load_state_dict(new_state_dict)
82
+
83
+ network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
84
+ network.eval()
85
+ network.to(device)
86
+
87
+
88
+
89
+ content_tf = test_transform(content_size, crop)
90
+ style_tf = test_transform(style_size, crop)
91
+
92
+
93
+ content_tf1 = content_transform()
94
+ content = content_tf(content_img.convert("RGB"))
95
+
96
+ h,w,c=np.shape(content)
97
+ style_tf1 = style_transform(h,w)
98
+ style = style_tf(style_img.convert("RGB"))
99
+
100
+
101
+ style = style.to(device).unsqueeze(0)
102
+ content = content.to(device).unsqueeze(0)
103
 
104
+ with torch.no_grad():
105
+ output= network(content,style)
106
+ output = output[0].cpu()
107
+ output = transforms.ToPILImage(output)
108
+ return output
109
+
110
+ def create_styledSofa(sofa:Image, style:Image):
111
+ styled_sofa = transformer(sofa,style)
112
  return styled_sofa
113
 
114
  # image = Image.open('sofa_office.jpg')
test.py DELETED
@@ -1,175 +0,0 @@
1
- import argparse
2
- from pathlib import Path
3
- import os
4
- import torch
5
- import torch.nn as nn
6
- from PIL import Image
7
- from os.path import basename
8
- from os.path import splitext
9
- from torchvision import transforms
10
- from torchvision.utils import save_image
11
- from function import calc_mean_std, normal, coral
12
- import transformer as transformer
13
- import StyTR as StyTR
14
- import matplotlib.pyplot as plt
15
- from matplotlib import cm
16
- from function import normal
17
- import numpy as np
18
-
19
- def test_transform(size, crop):
20
- transform_list = []
21
-
22
- if size != 0:
23
- transform_list.append(transforms.Resize(size))
24
- if crop:
25
- transform_list.append(transforms.CenterCrop(size))
26
- transform_list.append(transforms.ToTensor())
27
- transform = transforms.Compose(transform_list)
28
- return transform
29
- def style_transform(h,w):
30
- k = (h,w)
31
- size = int(np.max(k))
32
- transform_list = []
33
- transform_list.append(transforms.CenterCrop((h,w)))
34
- transform_list.append(transforms.ToTensor())
35
- transform = transforms.Compose(transform_list)
36
- return transform
37
-
38
- def content_transform():
39
-
40
- transform_list = []
41
- transform_list.append(transforms.ToTensor())
42
- transform = transforms.Compose(transform_list)
43
- return transform
44
-
45
-
46
-
47
- parser = argparse.ArgumentParser()
48
- # Basic options
49
- parser.add_argument('--content', type=str,
50
- help='File path to the content image')
51
- parser.add_argument('--content_dir', type=str,
52
- help='Directory path to a batch of content images')
53
- parser.add_argument('--style', type=str,
54
- help='File path to the style image, or multiple style \
55
- images separated by commas if you want to do style \
56
- interpolation or spatial control')
57
- parser.add_argument('--style_dir', type=str,
58
- help='Directory path to a batch of style images')
59
- parser.add_argument('--output', type=str, default='output',
60
- help='Directory to save the output image(s)')
61
- parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
62
- parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
63
- parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
64
- parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')
65
-
66
-
67
- parser.add_argument('--style_interpolation_weights', type=str, default="")
68
- parser.add_argument('--a', type=float, default=1.0)
69
- parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
70
- help="Type of positional embedding to use on top of the image features")
71
- parser.add_argument('--hidden_dim', default=512, type=int,
72
- help="Size of the embeddings (dimension of the transformer)")
73
- args = parser.parse_args()
74
-
75
- # Advanced options
76
- content_size=640
77
- style_size=640
78
- crop='store_true'
79
- save_ext='.jpg'
80
- output_path=args.output
81
- preserve_color='store_true'
82
- alpha=args.a
83
-
84
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
-
86
- # Either --content or --content_dir should be given.
87
- if args.content:
88
- content_paths = [Path(args.content)]
89
- else:
90
- content_dir = Path(args.content_dir)
91
- content_paths = [f for f in content_dir.glob('*')]
92
-
93
- # Either --style or --style_dir should be given.
94
- if args.style:
95
- style_paths = [Path(args.style)]
96
- else:
97
- style_dir = Path(args.style_dir)
98
- style_paths = [f for f in style_dir.glob('*')]
99
-
100
- if not os.path.exists(output_path):
101
- os.mkdir(output_path)
102
-
103
-
104
- vgg = StyTR.vgg
105
- vgg.load_state_dict(torch.load(args.vgg))
106
- vgg = nn.Sequential(*list(vgg.children())[:44])
107
-
108
- decoder = StyTR.decoder
109
- Trans = transformer.Transformer()
110
- embedding = StyTR.PatchEmbed()
111
-
112
- decoder.eval()
113
- Trans.eval()
114
- vgg.eval()
115
- from collections import OrderedDict
116
- new_state_dict = OrderedDict()
117
- state_dict = torch.load(args.decoder_path)
118
- for k, v in state_dict.items():
119
- #namekey = k[7:] # remove `module.`
120
- namekey = k
121
- new_state_dict[namekey] = v
122
- decoder.load_state_dict(new_state_dict)
123
-
124
- new_state_dict = OrderedDict()
125
- state_dict = torch.load(args.Trans_path)
126
- for k, v in state_dict.items():
127
- #namekey = k[7:] # remove `module.`
128
- namekey = k
129
- new_state_dict[namekey] = v
130
- Trans.load_state_dict(new_state_dict)
131
-
132
- new_state_dict = OrderedDict()
133
- state_dict = torch.load(args.embedding_path)
134
- for k, v in state_dict.items():
135
- #namekey = k[7:] # remove `module.`
136
- namekey = k
137
- new_state_dict[namekey] = v
138
- embedding.load_state_dict(new_state_dict)
139
-
140
- network = StyTR.StyTrans(vgg,decoder,embedding,Trans,args)
141
- network.eval()
142
- network.to(device)
143
-
144
-
145
-
146
- content_tf = test_transform(content_size, crop)
147
- style_tf = test_transform(style_size, crop)
148
-
149
- for content_path in content_paths:
150
- for style_path in style_paths:
151
-
152
-
153
- content_tf1 = content_transform()
154
- content = content_tf(Image.open(content_path).convert("RGB"))
155
-
156
- h,w,c=np.shape(content)
157
- style_tf1 = style_transform(h,w)
158
- style = style_tf(Image.open(style_path).convert("RGB"))
159
-
160
-
161
- style = style.to(device).unsqueeze(0)
162
- content = content.to(device).unsqueeze(0)
163
-
164
- with torch.no_grad():
165
- output= network(content,style)
166
- output = output[0].cpu()
167
-
168
- output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
169
- output_path, splitext(basename(content_path))[0],
170
- splitext(basename(style_path))[0], save_ext
171
- )
172
-
173
- save_image(output, output_name)
174
-
175
-