sunshineatnoon committed
Commit 1d90a68
1 Parent(s): d4e058e

Add application file

Files changed (1)
  1. app.py +307 -0
app.py ADDED
@@ -0,0 +1,307 @@
import os
import time
import json
import base64
import argparse
import importlib
from glob import glob
from PIL import Image
from imageio import imsave

import torch
import torchvision.utils as vutils

import sys
sys.path.append(".")

import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
from skimage.segmentation import mark_boundaries
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat

import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images
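
# Note: the libs.* and models.week0417.* imports are project-local modules
# (not on PyPI); sys.path.append(".") above makes them importable when the
# app is launched from the repository root.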

args = BaseOptions().gather_options()
if args.img_path is not None:
    args.exp_name = os.path.join(args.exp_name, args.img_path.split('/')[-1].split('.')[0])
args.batch_size = 1
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"
args.device = torch.device("cpu")
args.nsamples = 500
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = (args.test_time == 1)

if not hasattr(args, 'tex_code_dim'):
    args.tex_code_dim = 256
class Tester(TesterBase):
    def define_model(self):
        """Define the model."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()
        return

    def draw_color_seg(self, seg):
        # color-encode a batch of segmentation maps for visualization
        seg = seg.detach().cpu().numpy()
        color_ = []
        for i in range(seg.shape[0]):
            colori = colorEncode(seg[i].squeeze())
            colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1)
            color_.append(colori)
        color_ = torch.stack(color_)
        return color_

    def to_pil(self, tensor):
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")
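
    # display() renders the main UI: the grouping mask, the per-segment crops,
    # and, once a segment is clicked, an arbitrary-size synthesized texture.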
    def display(self):
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward(self.data, self.slic, return_type='grouping')

            data = (self.data + 1) / 2.0

            seg = grouping_mask.view(-1, 1, args.crop_size, args.crop_size)
            color_vq = self.draw_color_seg(seg)
            color_vq = color_vq * 0.8 + data.cpu() * 0.2

            st.markdown('<p class="big-font">Given the image you chose, our model decomposes the image into ten texture segments, each depicting one kind of texture in the image.</p>', unsafe_allow_html=True)
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.markdown("")

            with col2:
                st.markdown("Chosen image")
                st.image(self.to_pil(data))

            with col3:
                st.markdown("Grouping mask")
                st.image(self.to_pil(color_vq))

            with col4:
                st.markdown("")

            seg_onehot = label2one_hot_torch(seg, C=10)
            parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)

        st.markdown('<p class="big-font">We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, choose and click one of the segments below.</p>', unsafe_allow_html=True)
        tmp_img_list = []
        for i in range(parts.shape[0]):
            part_img = self.to_pil(parts[i])
            out_path = '/home/xli/Dropbox/PAS/tmp/{}.png'.format(i)
            part_img.save(out_path)

            with open(out_path, "rb") as image:
                encoded = base64.b64encode(image.read()).decode()
                tmp_img_list.append(f"data:image/jpeg;base64,{encoded}")

        # clickable_images returns the index of the clicked image (-1 before any click)
        tex_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=0
        )

        if tex_idx > -1:
            with st.spinner('Running...'):
                st.markdown('<p class="big-font">You can slide the bar below to set the size of the synthesized texture image.</p>', unsafe_allow_html=True)
                tex_size = st.slider('', 0, 1000, 256)
                tex_size = (tex_size // 8) * 8  # the generator operates on multiples of 8
                with torch.no_grad():
                    tex = self.model_forward(self.data, self.slic, tex_idx=tex_idx, tex_size=tex_size, return_type='tex')
                col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
                with col1:
                    st.markdown("")

                with col2:
                    st.markdown("Chosen exemplar segment")
                    st.image(self.to_pil(parts[tex_idx]))

                with col3:
                    st.markdown("Synthesized texture image")
                    st.image(self.to_pil(tex))

                with col4:
                    st.markdown("")
                st.markdown('<p class="big-font">You can choose another image from the exemplar images at the top and start again!</p>', unsafe_allow_html=True)
                #torch.cuda.empty_cache()

        """
        st.markdown("#### Texture Editing")
        st.markdown("**Choose one texture segment to remove.**")
        remove_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "120px"},
            key=1
        )
        st.markdown("**Choose one texture segment to fill in the missing pixels.**")
        fill_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "120px"},
            key=2
        )
        rec = self.model_forward(self.data, self.slic, return_type='editing', fill_idx=fill_idx, remove_idx=remove_idx)
        st.image(self.to_pil(rec))
        """
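
    # model_forward pipeline (a sketch of the data flow; shapes assume square
    # crop_size inputs):
    #   enc:        (B, 3, H, W)            -> conv features (B, C, H/8, W/8)
    #   ToTexCode:  conv features           -> per-pixel texture code
    #   poolfeat:   texture code + slic map -> one code per superpixel
    #   gcn:        superpixel codes        -> soft grouping into ~10 texture clusters
    #   G:          sine waves + noise      -> synthesized RGB texture in [-1, 1]
    # poolfeat/upfeat are assumed to average features over superpixels and to
    # broadcast pooled features back to pixels, respectively.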
    def model_forward(self, rgb_img, slic, epoch=1000, test_time=False,
                      test=True, tex_idx=None, tex_size=256,
                      return_type='tex', fill_idx=None, remove_idx=None):
        args = self.args
        B, _, imgH, imgW = rgb_img.shape

        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        B, C, H, W = conv_feat.shape

        # Texture code for each superpixel
        tex_code = self.model.ToTexCode(conv_feat)

        code = F.interpolate(tex_code, size=(imgH, imgW), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)

        prop_code, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))
        softmax = F.softmax(sp_assign * args.temperature, dim=1)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)

        tex_seg = poolfeat(conv_feats, softmax, avg=True)
        seg = label2one_hot_torch(torch.argmax(softmax, dim=1).unsqueeze(1), C=softmax.shape[1])

        if return_type == 'tex':
            # broadcast the chosen segment's code to a tex_size x tex_size canvas
            sampled_code = tex_seg[:, tex_idx, :]
            rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
            sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
            H = tex_size // 8
            W = tex_size // 8
            noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
            dec_input = torch.cat((sine_wave, noise), dim=1)

            weight = self.model.ChannelWeight(rec_tex)
            weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
            weight = torch.sigmoid(weight)
            dec_input *= weight

            rep_rec = self.model.G(dec_input, rec_tex)
            rep_rec = (rep_rec + 1) / 2.0
            return rep_rec
        elif return_type == 'editing':
            # swap the removed segment's code for the fill segment's code
            rec_tex = upfeat(tex_seg, seg)
            remove_mask = seg[:, remove_idx:remove_idx + 1]
            fill_tex = tex_seg[:, fill_idx, :].view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
            rec_tex = rec_tex * (1 - remove_mask) + fill_tex * remove_mask

            sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
            H = imgH // 8
            W = imgW // 8
            noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
            dec_input = torch.cat((sine_wave, noise), dim=1)
            weight = self.model.ChannelWeight(rec_tex)
            weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
            weight = torch.sigmoid(weight)
            dec_input *= weight

            rep_rec = self.model.G(dec_input, rec_tex)
            rep_rec = (rep_rec + 1) / 2.0
            return rep_rec
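
    # load_data(): SLIC superpixels are computed once per image and handed to
    # the model as a one-hot (1, sp_num, H, W) assignment map.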
    def load_data(self, data_path):
        rgb_img = Image.open(data_path)
        crop_size = self.args.crop_size
        # fixed crop_size x crop_size crop starting at (40, 40)
        i, j, h, w = 40, 40, crop_size, crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)

        # compute superpixels
        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(args.device)

        rgb_img = TF.to_tensor(rgb_img)
        rgb_img = rgb_img.unsqueeze(0)
        self.data = rgb_img.to(args.device) * 2 - 1  # normalize to [-1, 1]
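
    # The checkpoints were saved from a DataParallel-wrapped model, so we wrap,
    # load, then unwrap to strip the "module." prefix from parameter names.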
    def load_model(self, model_path):
        self.model = torch.nn.DataParallel(self.model)
        # map_location ensures a GPU-trained checkpoint loads on a CPU device
        cpk = torch.load(model_path, map_location=args.device)
        saved_state_dict = cpk['model']
        self.model.load_state_dict(saved_state_dict)
        self.model = self.model.module
        return
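
    # test() is an offline entry point; the Streamlit flow in main() calls
    # display() directly instead.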
    def test(self):
        """Test function."""
        #for iteration in tqdm(range(args.nsamples)):
        self.test_step(0)
        self.display()
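
# To run the demo locally (assuming the example images and per-image
# pretrained checkpoints referenced below are in place), something like:
#   pip install streamlit st-clickable-images torch torchvision scikit-image
#   streamlit run app.py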
def main():
    #torch.cuda.empty_cache()
    st.set_page_config(layout="wide")
    st.markdown("""
        <style>
        .big-font {
            font-size:30px !important;
        }
        </style>
        """, unsafe_allow_html=True)

    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    #st.markdown("**In this demo, we show how to scrape textures from natural images for texture synthesis and editing.**")
    st.markdown('<p class="big-font">In this demo, we show how to scrape textures from natural images for texture synthesis and editing.</p>', unsafe_allow_html=True)
    st.markdown("## Texture synthesis")
    st.markdown('<p class="big-font">Here we provide a set of example images; please choose and click one image to start.</p>', unsafe_allow_html=True)
    img_list = glob("data/images/*.jpg")
    test_img_list = glob("data/test_images/*.jpg")
    img_list.extend(test_img_list)
    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
            byte_img_list.append(f"data:image/jpeg;base64,{encoded}")
    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Image #{str(i)}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )

    if img_idx > -1:
        # resolve the chosen image and its per-image pretrained checkpoint
        img_path = img_list[img_idx]
        img_name = img_path.split("/")[-1]
        args.pretrained_path = "/home/xli/WORKDIR/04-18/{}/cpk.pth".format(img_name.split(".")[0])

        tester = Tester(args)
        tester.define_model()
        tester.load_data(img_path)
        tester.load_model(args.pretrained_path)
        tester.display()


if __name__ == '__main__':
    main()