"""Streamlit demo: scrape textures from natural images for synthesis and editing.

The page lets the user pick an example image, groups its superpixels into
texture segments with a per-image pretrained model (``weights/<name>/cpk.pth``),
and then either synthesizes a texture of arbitrary size from one segment or
re-paints a label mask using user-chosen segments.
"""
import os
import io
import sys
import time
import json
import base64
import argparse
import importlib
import subprocess
from glob import glob

from PIL import Image
from imageio import imsave
import torch
import torchvision.utils as vutils

# Add the current working directory to the import path so the project-local
# ``libs`` and ``models`` packages resolve.
sys.path.append(".")

import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat
import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images

args = BaseOptions().gather_options()
if args.img_path is not None:
    args.exp_name = os.path.join(args.exp_name, args.img_path.split('/')[-1].split('.')[0])
args.batch_size = 1
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"
args.device = torch.device("cpu")
args.nsamples = 500
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = (args.test_time == 1)
if not hasattr(args, 'tex_code_dim'):
    args.tex_code_dim = 256

# Number of texture groups the model is assumed to produce. The UI text
# ("ten texture segments") and the per-group label pickers rely on it.
NUM_GROUPS = 10
# Side length (pixels) of the canvas used for texture editing.
EDIT_SIZE = 512
# torch-geometric release the demo environment is pinned to.
TORCH_GEOMETRIC_VERSION = "1.7.2"

_GALLERY_DIV_STYLE = {"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}
_GALLERY_IMG_STYLE = {"margin": "5px", "height": "150px"}


def _bytes_to_data_uri(raw, mime):
    """Return ``raw`` bytes as a base64 ``data:`` URI with the given MIME type."""
    return "data:{};base64,{}".format(mime, base64.b64encode(raw).decode())


def _pil_to_data_uri(img):
    """Encode a PIL image as PNG in memory and return it as a data URI.

    Encoding in memory (instead of round-tripping through files in ``tmp/``)
    avoids failing when ``tmp/`` does not exist and stops concurrent sessions
    from overwriting each other's thumbnails.
    """
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return _bytes_to_data_uri(buf.getvalue(), "image/png")


def _tensor_to_data_uri(tensor):
    """Encode an image tensor via ``torchvision.utils.save_image`` as a PNG data URI."""
    buf = io.BytesIO()
    vutils.save_image(tensor, buf, format="PNG")
    return _bytes_to_data_uri(buf.getvalue(), "image/png")


def _clickable_gallery(uris, titles=None, key=None):
    """Render ``uris`` as a centered, wrapping row of clickable thumbnails.

    Returns the index reported by ``clickable_images``; this file treats a
    negative value as "nothing clicked yet".
    """
    kwargs = {
        "div_style": dict(_GALLERY_DIV_STYLE),
        "img_style": dict(_GALLERY_IMG_STYLE),
        "key": key,
    }
    # Only forward titles when given so the component keeps its own default.
    if titles is not None:
        kwargs["titles"] = titles
    return clickable_images(uris, **kwargs)


def _load_label_mask(path):
    """Load an image as an 8-bit grayscale label map and return it as a 2-D uint8 tensor."""
    with Image.open(path) as img:
        labels = np.array(img.convert("L"))
    return torch.from_numpy(labels)


def _ensure_torch_geometric():
    """Install the pinned torch-geometric release unless it is already installed.

    Streamlit re-executes this script on every interaction, so an
    unconditional ``pip install`` would run on every rerun.
    """
    try:
        from importlib.metadata import version
        if version("torch_geometric") == TORCH_GEOMETRIC_VERSION:
            return
    except ImportError:
        # Covers Python < 3.8 (no importlib.metadata) as well as
        # PackageNotFoundError, which subclasses ModuleNotFoundError.
        pass
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch-geometric=={}".format(TORCH_GEOMETRIC_VERSION)],
        check=False,
    )


class Tester(TesterBase):
    def define_model(self):
        """Instantiate ``models.week0417.<model_name>.AE`` on ``args.device`` in eval mode."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()

    def draw_color_seg(self, seg):
        """Colorize a batch of label maps.

        Args:
            seg: tensor whose first dim indexes label maps; each ``seg[i]`` is
                squeezed and passed to ``colorEncode``.

        Returns:
            Float tensor stacking the ``colorEncode`` outputs as channel-first
            images scaled to [0, 1].
        """
        labels = seg.detach().cpu().numpy()
        colored = [
            torch.from_numpy(colorEncode(label.squeeze()) / 255.0).float().permute(2, 0, 1)
            for label in labels
        ]
        return torch.stack(colored)

    def to_pil(self, tensor):
        """Convert an image tensor (values clamped to [0, 1]) to an RGB PIL image."""
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")

    def _show_grouping(self, forward, segments_text, gallery_key):
        """Show the chosen image, its grouping mask and a gallery of texture segments.

        Args:
            forward: one of the ``model_forward_*`` methods; called with
                ``return_type='grouping'``.
            segments_text: paragraph shown above the segment gallery.
            gallery_key: Streamlit key for the segment gallery widget.

        Returns:
            ``(parts, clicked)``: per-group masked copies of the input image and
            the index of the clicked segment (negative if none).
        """
        crop_size = self.args.crop_size
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = forward(self.data, self.slic, return_type='grouping')
        # self.data is stored in [-1, 1]; map it back to [0, 1] for display.
        data = (self.data + 1) / 2.0
        seg = grouping_mask.view(-1, 1, crop_size, crop_size)
        color_vq = self.draw_color_seg(seg) * 0.8 + data.cpu() * 0.2

        st.markdown(
            "Given the image you chose, our model decomposes the image into ten texture "
            "segments, each depicting one kind of texture in the image.",
            unsafe_allow_html=True,
        )
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.markdown("")
        with col2:
            st.markdown("Chosen image")
            st.image(self.to_pil(data))
        with col3:
            st.markdown("Grouping mask")
            st.image(self.to_pil(color_vq))
        with col4:
            st.markdown("")

        # One masked copy of the image per group: (NUM_GROUPS, 3, H, W).
        seg_onehot = label2one_hot_torch(seg, C=NUM_GROUPS)
        parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)

        st.markdown(segments_text, unsafe_allow_html=True)
        uris = [_pil_to_data_uri(self.to_pil(part)) for part in parts]
        clicked = _clickable_gallery(
            uris,
            titles=[f"Group #{i}" for i in range(len(uris))],
            key=gallery_key,
        )
        return parts, clicked

    def display_synthesis(self):
        """Texture-synthesis page: pick a segment, then synthesize a texture of slider-chosen size."""
        parts, tex_idx = self._show_grouping(
            self.model_forward_synthesis,
            "We show all texture segments below. To synthesize an arbitrary-sized texture "
            "image from a texture segment, choose and click one of the texture segments below.",
            gallery_key=0,
        )
        if tex_idx < 0:
            return

        with st.spinner('Running...'):
            st.markdown(
                "You can slide the bar below to set the size of the synthesized texture image.",
                unsafe_allow_html=True,
            )
            tex_size = st.slider('', 0, 1000, 256)
            # Round down to a multiple of 8 (the noise map is tex_size // 8 per
            # side) and never go below 8, which would yield an empty noise map.
            tex_size = max(8, (tex_size // 8) * 8)
            with torch.no_grad():
                tex = self.model_forward_synthesis(
                    self.data, self.slic, tex_idx=tex_idx, tex_size=tex_size, return_type='tex')
            col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
            with col1:
                st.markdown("")
            with col2:
                st.markdown("Chosen exemplar segment")
                st.image(self.to_pil(parts[tex_idx]))
            with col3:
                st.markdown("Synthesized texture image")
                st.image(self.to_pil(tex))
            with col4:
                st.markdown("")
        st.markdown(
            "You can choose another image from the exemplar images on the top and start again!",
            unsafe_allow_html=True,
        )

    def _group_superpixels(self, rgb_img, sp_onehot, epoch):
        """Encode ``rgb_img`` and group its superpixels with the model's GCN.

        Returns ``(tex_code, sp_assign, conv_feats)`` as produced by
        ``ToTexCode`` and ``gcn``.
        """
        args = self.args
        img_h, img_w = rgb_img.shape[-2:]
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        # Texture code for each superpixel
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size=(img_h, img_w), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, sp_onehot, avg=True)
        _, sp_assign, conv_feats = self.model.gcn(pool_code, sp_onehot, args.add_clustering_epoch <= epoch)
        return tex_code, sp_assign, conv_feats

    def _group_texture_codes(self, sp_assign, conv_feats):
        """Pool GCN features into one texture code per group (soft assignment)."""
        softmax = F.softmax(sp_assign * self.args.temperature, dim=1)
        return poolfeat(conv_feats, softmax, avg=True)

    def _decode(self, rec_tex, batch_size, device):
        """Decode a dense texture-code map into an image rescaled from [-1, 1] to [0, 1].

        The noise map has ``rec_tex``'s spatial size divided by 8.
        """
        out_h, out_w = rec_tex.shape[-2:]
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
        noise = torch.randn(batch_size, self.model.sine_wave_dim, out_h // 8, out_w // 8).to(device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0

    def model_forward_synthesis(self, rgb_img, slic, epoch = 1000, test_time = False, test = True, tex_idx = None, tex_size = 256, return_type = 'tex', fill_idx = None, remove_idx = None):
        """Group superpixels or synthesize a square texture from one group.

        Args:
            rgb_img: input image batch in [-1, 1].
            slic: one-hot superpixel assignment (as built by ``load_data``).
            epoch: compared with ``args.add_clustering_epoch`` to enable clustering in the GCN.
            tex_idx: group whose texture code is synthesized.
            tex_size: side length of the synthesized texture.
            return_type: ``'grouping'`` returns the per-superpixel argmax group
                labels (on CPU); anything else returns the synthesized texture.
            test_time, test, fill_idx, remove_idx: accepted for interface
                compatibility; unused.
        """
        tex_code, sp_assign, conv_feats = self._group_superpixels(rgb_img, slic, epoch)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)

        tex_seg = self._group_texture_codes(sp_assign, conv_feats)
        sampled_code = tex_seg[:, tex_idx, :]
        rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
        return self._decode(rec_tex, rgb_img.shape[0], tex_code.device)

    def display_editing(self):
        """Texture-editing page: pick a label mask, assign a group to each label, render the edit."""
        self._show_grouping(
            self.model_forward_editing,
            "We show all texture segments below.",
            gallery_key=2,
        )

        st.markdown("Choose one mask for texture editing.", unsafe_allow_html=True)
        mask_list = sorted(glob("data/masks/*.png"))
        if not mask_list:
            st.warning("No masks found in data/masks.")
            return
        colored_masks = []
        for path in mask_list:
            labels = _load_label_mask(path)
            colored_masks.append(self.draw_color_seg(labels.view(1, 1, *labels.shape)))
        mask_idx = _clickable_gallery(
            [_tensor_to_data_uri(colored) for colored in colored_masks],
            titles=[f"Group #{i}" for i in range(len(mask_list))],
        )
        # Wait for an explicit choice instead of silently using the last mask.
        if mask_idx < 0:
            return
        mask_path = mask_list[mask_idx]

        st.markdown(
            "Choose the texture segment for each group in the given mask below.",
            unsafe_allow_html=True,
        )
        labels = _load_label_mask(mask_path)
        height, width = labels.shape
        # int() avoids uint8 overflow when the mask contains label 255.
        num_labels = int(labels.max()) + 1
        given_mask = label2one_hot_torch(labels.view(1, 1, height, width), C=num_labels)
        given_mask = F.interpolate(given_mask, size=(EDIT_SIZE, EDIT_SIZE), mode='bilinear', align_corners=False)

        # Gallery of the individual mask regions; purely informational.
        _clickable_gallery(
            [_pil_to_data_uri(self.to_pil(given_mask[0, i])) for i in range(given_mask.shape[1])],
            key=1,
        )
        options = []
        for i, col in enumerate(st.columns(given_mask.shape[1])):
            with col:
                # Unique string keys: integer keys would collide with the
                # gallery widgets keyed 1 and 2 above.
                choice = st.selectbox("", [str(ii) for ii in range(NUM_GROUPS)], key=f"mask_label_{i}")
            options.append(int(choice))

        with st.spinner('Running...'):
            st.markdown("Edited image is shown below.", unsafe_allow_html=True)
            with torch.no_grad():
                edited = self.model_forward_editing(
                    self.data, self.slic, options=options, given_mask=given_mask, return_type='edited')
            col1, col2, col3 = st.columns([1, 1, 1])
            with col1:
                st.markdown("Input image")
                img = F.interpolate(self.data, size=edited.shape[-2:], mode='bilinear', align_corners=False)
                st.image(self.to_pil((img + 1) / 2.0))
            with col2:
                st.markdown("Given mask")
                color_vq = F.interpolate(
                    colored_masks[mask_idx], size=(EDIT_SIZE, EDIT_SIZE), mode='bilinear', align_corners=False)
                st.image(self.to_pil(color_vq))
            with col3:
                st.markdown("Synthesized texture image")
                st.image(self.to_pil(edited))
        st.markdown(
            "You can choose another image from the exemplar images on the top and start again!",
            unsafe_allow_html=True,
        )

    def model_forward_editing(self, rgb_img, slic, epoch = 1000, test_time = False, test = True, tex_idx = None, tex_size = 256, return_type = 'edited', fill_idx = None, remove_idx = None, options = None, given_mask = None):
        """Group superpixels or paint ``given_mask`` with the chosen group textures.

        Args:
            rgb_img: input image batch in [-1, 1].
            slic: one-hot superpixel assignment (as built by ``load_data``).
            epoch: compared with ``args.add_clustering_epoch`` to enable clustering in the GCN.
            return_type: ``'grouping'`` returns the per-superpixel argmax group
                labels (on CPU); anything else returns the edited image.
            options: ``options[i]`` is the group whose texture fills mask channel ``i``.
            given_mask: soft one-hot mask of shape (1, L, H, W); the output has
                spatial size (H, W).
            test_time, test, tex_idx, tex_size, fill_idx, remove_idx: accepted for
                interface compatibility; unused (output size follows ``given_mask``).

        Raises:
            ValueError: if ``options`` or ``given_mask`` is missing for an edit.
        """
        tex_code, sp_assign, conv_feats = self._group_superpixels(rgb_img, slic, epoch)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)
        if options is None or given_mask is None:
            raise ValueError("model_forward_editing needs `options` and `given_mask` for editing.")

        tex_seg = self._group_texture_codes(sp_assign, conv_feats)
        device = tex_seg.device
        given_mask = given_mask.to(device)
        out_h, out_w = given_mask.shape[-2:]
        # Blend the chosen group codes, weighted by each mask channel.
        rec_tex = torch.zeros((1, tex_seg.shape[-1], out_h, out_w), device=device)
        for i in range(given_mask.shape[1]):
            code = tex_seg[0, options[i], :].view(1, -1, 1, 1)
            rec_tex += code * given_mask[:, i:i + 1]
        return self._decode(rec_tex, rgb_img.shape[0], tex_code.device)

    def load_data(self, data_path):
        """Load, resize and crop an image, and compute its SLIC superpixels.

        Sets ``self.data`` (1, 3, crop, crop) in [-1, 1] and ``self.slic``, a
        one-hot superpixel assignment with ``sp_num`` channels.
        """
        args = self.args
        crop_size = args.crop_size
        with Image.open(data_path) as img:
            # convert("RGB") guards against grayscale/CMYK inputs.
            rgb_img = transforms.Resize(size=320)(img.convert("RGB"))
        # Fixed crop whose top-left corner is at (40, 40).
        rgb_img = TF.crop(rgb_img, 40, 40, crop_size, crop_size)

        # compute superpixel
        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        # n_segments is only approximate; fold any extra labels into the last one.
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(args.device)

        self.data = TF.to_tensor(rgb_img).unsqueeze(0).to(args.device) * 2 - 1

    def load_model(self, model_path):
        """Load ``cpk['model']`` from ``model_path`` into ``self.model`` (on CPU).

        The model is temporarily wrapped in ``DataParallel`` so the state-dict
        keys match — presumably the checkpoint was saved from a DataParallel
        model (``module.`` prefixes); TODO confirm.
        """
        self.model = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(cpk['model'])
        self.model = self.model.module


def main():
    """Build the demo page: example-image picker, then synthesis or editing."""
    st.set_page_config(layout="wide")
    # NOTE(review): this markdown block is empty; it probably held custom CSS
    # that was lost. Confirm and restore or remove.
    st.markdown(""" """, unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    st.markdown(
        "In this demo, we show how to scrape textures from natural images for texture "
        "synthesis and editing.",
        unsafe_allow_html=True,
    )
    st.markdown("## Texture synthesis")
    st.markdown(
        "Here we provide a set of example images, please choose and click one image to start.",
        unsafe_allow_html=True,
    )

    # Sorted so a click index maps to the same file on every rerun.
    img_list = sorted(glob("data/images/*.jpg")) + sorted(glob("data/test_images/*.jpg"))
    byte_img_list = []
    for path in img_list:
        with open(path, "rb") as fh:
            byte_img_list.append(_bytes_to_data_uri(fh.read(), "image/jpeg"))
    img_idx = _clickable_gallery(
        byte_img_list,
        titles=[f"Group #{i}" for i in range(len(byte_img_list))],
    )
    # Index only after a real click; -1 would otherwise select the last image.
    if img_idx < 0:
        return

    img_path = img_list[img_idx]
    img_name = os.path.basename(img_path)
    args.pretrained_path = "weights/{}/cpk.pth".format(img_name.split(".")[0])

    tester = Tester(args)
    tester.define_model()
    tester.load_data(img_path)
    tester.load_model(args.pretrained_path)

    app_idx = st.selectbox('Please select between texture synthesis or editing',
                           ["Texture Synthesis", "Texture Editing"])
    if app_idx == 'Texture Editing':
        st.header("Texture Editing")
        tester.display_editing()
    else:
        st.header("Texture Synthesis")
        tester.display_synthesis()


if __name__ == '__main__':
    _ensure_torch_geometric()
    main()