Sophie98 committed on
Commit
8e6efc1
1 Parent(s): a402a5e

make transformer global

Browse files
Files changed (2) hide show
  1. app.py +3 -1
  2. styleTransfer.py +67 -72
app.py CHANGED
@@ -90,10 +90,12 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
90
  input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
91
  resized_img,box = resize_sofa(input_img)
92
  resized_style = resize_style(style_img)
 
93
  # generate mask for image
94
  mask = get_mask(resized_img)
 
95
  styled_sofa = create_styledSofa(resized_img,resized_style)
96
- print(type(styled_sofa))
97
  # postprocess the final image
98
  new_sofa = replace_sofa(resized_img,mask,styled_sofa)
99
  new_sofa = new_sofa.crop(box)
90
  input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
91
  resized_img,box = resize_sofa(input_img)
92
  resized_style = resize_style(style_img)
93
+ resized_style.save('resized_style.jpg')
94
  # generate mask for image
95
  mask = get_mask(resized_img)
96
+ mask.save('mask.jpg')
97
  styled_sofa = create_styledSofa(resized_img,resized_style)
98
+ styled_sofa.save('styled_sofa.jpg')
99
  # postprocess the final image
100
  new_sofa = replace_sofa(resized_img,mask,styled_sofa)
101
  new_sofa = new_sofa.crop(box)
styleTransfer.py CHANGED
@@ -11,10 +11,15 @@ from collections import OrderedDict
11
  import tensorflow_hub as hub
12
  import tensorflow as tf
13
 
14
- style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
15
-
16
 
17
  ############################################# TRANSFORMER ############################################
 
 
 
 
 
 
18
  def test_transform(size, crop):
19
  transform_list = []
20
 
@@ -40,82 +45,75 @@ def content_transform():
40
  transform = transforms.Compose(transform_list)
41
  return transform
42
 
43
- def StyleTransformer(content_img: Image, style_img: Image,
44
- vgg_path:str = 'vgg_normalised.pth', decoder_path:str = 'decoder_iter_160000.pth',
45
- Trans_path:str = 'transformer_iter_160000.pth', embedding_path:str = 'embedding_iter_160000.pth'):
46
- # Advanced options
47
- content_size=640
48
- style_size=640
49
- crop='store_true'
50
-
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
-
53
- vgg = StyTR.vgg
54
- vgg.load_state_dict(torch.load(vgg_path))
55
- vgg = nn.Sequential(*list(vgg.children())[:44])
56
-
57
- decoder = StyTR.decoder
58
- Trans = transformer.Transformer()
59
- embedding = StyTR.PatchEmbed()
60
-
61
- decoder.eval()
62
- Trans.eval()
63
- vgg.eval()
64
-
65
- new_state_dict = OrderedDict()
66
- state_dict = torch.load(decoder_path)
67
- for k, v in state_dict.items():
68
- #namekey = k[7:] # remove `module.`
69
- namekey = k
70
- new_state_dict[namekey] = v
71
- decoder.load_state_dict(new_state_dict)
72
-
73
- new_state_dict = OrderedDict()
74
- state_dict = torch.load(Trans_path)
75
- for k, v in state_dict.items():
76
- #namekey = k[7:] # remove `module.`
77
- namekey = k
78
- new_state_dict[namekey] = v
79
- Trans.load_state_dict(new_state_dict)
80
-
81
- new_state_dict = OrderedDict()
82
- state_dict = torch.load(embedding_path)
83
- for k, v in state_dict.items():
84
- #namekey = k[7:] # remove `module.`
85
- namekey = k
86
- new_state_dict[namekey] = v
87
- embedding.load_state_dict(new_state_dict)
88
-
89
- network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
90
- network.eval()
91
- network.to(device)
92
-
93
- content_tf = test_transform(content_size, crop)
94
- style_tf = test_transform(style_size, crop)
95
-
96
  content_tf1 = content_transform()
97
  content = content_tf(content_img.convert("RGB"))
98
-
99
  h,w,c=np.shape(content)
100
  style_tf1 = style_transform(h,w)
101
  style = style_tf(style_img.convert("RGB"))
102
-
103
-
104
  style = style.to(device).unsqueeze(0)
105
  content = content.to(device).unsqueeze(0)
106
 
107
  with torch.no_grad():
108
  output= network(content,style)
109
- print(type(output))
110
- output = output[0].cpu()
111
- print(type(output))
112
- print(output.squeeze().shape)
113
- torch2PIL = transforms.ToPILImage()
114
- output = torch2PIL(output.squeeze())
115
- return output
116
 
117
  ############################################## STYLE-GAN #############################################
118
 
 
 
119
  def perform_style_transfer(content_image, style_image):
120
  content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
121
  style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
@@ -130,10 +128,7 @@ def create_styledSofa(sofa:Image, style:Image):
130
  styled_sofa = StyleTransformer(sofa,style)
131
  return styled_sofa
132
 
133
- # image = Image.open('sofa_office.jpg')
134
- # image.show()
135
- # image = np.array(image)
136
- # image,box = resize_sofa(image)
137
- # image = image.crop(box)
138
- # print(box)
139
- # image.show()
11
  import tensorflow_hub as hub
12
  import tensorflow as tf
13
 
14
+ from torchvision.utils import save_image
 
15
 
16
  ############################################# TRANSFORMER ############################################
17
+
18
+ vgg_path = 'vgg_normalised.pth'
19
+ decoder_path = 'decoder_iter_160000.pth'
20
+ Trans_path = 'transformer_iter_160000.pth'
21
+ embedding_path = 'embedding_iter_160000.pth'
22
+
23
  def test_transform(size, crop):
24
  transform_list = []
25
 
45
  transform = transforms.Compose(transform_list)
46
  return transform
47
 
48
+ # Advanced options
49
+ content_size=640
50
+ style_size=640
51
+ crop='store_true'
52
+
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ vgg = StyTR.vgg
56
+ vgg.load_state_dict(torch.load(vgg_path))
57
+ vgg = nn.Sequential(*list(vgg.children())[:44])
58
+
59
+ decoder = StyTR.decoder
60
+ Trans = transformer.Transformer()
61
+ embedding = StyTR.PatchEmbed()
62
+
63
+ decoder.eval()
64
+ Trans.eval()
65
+ vgg.eval()
66
+
67
+ new_state_dict = OrderedDict()
68
+ state_dict = torch.load(decoder_path)
69
+ for k, v in state_dict.items():
70
+ #namekey = k[7:] # remove `module.`
71
+ namekey = k
72
+ new_state_dict[namekey] = v
73
+ decoder.load_state_dict(new_state_dict)
74
+
75
+ new_state_dict = OrderedDict()
76
+ state_dict = torch.load(Trans_path)
77
+ for k, v in state_dict.items():
78
+ #namekey = k[7:] # remove `module.`
79
+ namekey = k
80
+ new_state_dict[namekey] = v
81
+ Trans.load_state_dict(new_state_dict)
82
+
83
+ new_state_dict = OrderedDict()
84
+ state_dict = torch.load(embedding_path)
85
+ for k, v in state_dict.items():
86
+ #namekey = k[7:] # remove `module.`
87
+ namekey = k
88
+ new_state_dict[namekey] = v
89
+ embedding.load_state_dict(new_state_dict)
90
+
91
+ network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
92
+ network.eval()
93
+ network.to(device)
94
+
95
+ content_tf = test_transform(content_size, crop)
96
+ style_tf = test_transform(style_size, crop)
97
+
98
+ def StyleTransformer(content_img: Image, style_img: Image):
 
 
99
  content_tf1 = content_transform()
100
  content = content_tf(content_img.convert("RGB"))
 
101
  h,w,c=np.shape(content)
102
  style_tf1 = style_transform(h,w)
103
  style = style_tf(style_img.convert("RGB"))
 
 
104
  style = style.to(device).unsqueeze(0)
105
  content = content.to(device).unsqueeze(0)
106
 
107
  with torch.no_grad():
108
  output= network(content,style)
109
+ output = output[0].cpu().squeeze()
110
+ output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
111
+ return Image.fromarray(output)
 
 
 
 
112
 
113
  ############################################## STYLE-GAN #############################################
114
 
115
+ style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
116
+
117
  def perform_style_transfer(content_image, style_image):
118
  content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
119
  style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
128
  styled_sofa = StyleTransformer(sofa,style)
129
  return styled_sofa
130
 
131
+ image = Image.open('sofa.jpg')
132
+ style = Image.open('style.jpg')
133
+ output = create_styledSofa(image,style)
134
+ output.show()