import os
import io
import sys
import time
import json
import base64
import argparse
import importlib
import importlib.util
import subprocess
from glob import glob
from PIL import Image
from imageio import imsave
import torch
import torchvision.utils as vutils

# Make the project-local packages (``libs`` below, ``models`` in define_model)
# importable; presumably the script is launched from the repository root.
sys.path.append(".")

import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat
import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images

args = BaseOptions().gather_options()
if args.img_path is not None:
    args.exp_name = os.path.join(args.exp_name, args.img_path.split('/')[-1].split('.')[0])
args.batch_size = 1
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"
args.device = torch.device("cpu")
args.nsamples = 500
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = (args.test_time == 1)
if not hasattr(args, 'tex_code_dim'):
    args.tex_code_dim = 256

# Number of texture groups the UI one-hot encodes and displays (the page text
# says "ten texture segments"). Presumably must match the model's cluster
# count (sp_assign.shape[1]) -- TODO confirm.
NUM_TEX_GROUPS = 10


def _markdown_paragraph(text):
    """Render ``text`` as a stand-alone Markdown paragraph (raw HTML allowed)."""
    st.markdown('\n\n' + text + '\n\n', unsafe_allow_html=True)


def _png_data_uri(pil_img):
    """Encode a PIL image as an in-memory PNG ``data:`` URI (no temp files)."""
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"


class Tester(TesterBase):
    """Streamlit front-end around a texture-grouping / synthesis autoencoder."""

    def define_model(self):
        """Import ``models.week0417.<model_name>``, build its ``AE`` and put it in eval mode."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()
        return

    def draw_color_seg(self, seg):
        """Colour-code a batch of label maps.

        Args:
            seg: label tensor indexed as ``seg[i]`` per batch item.

        Returns:
            Float tensor stacked over the batch, channels-first, values in [0, 1]
            (``colorEncode`` output divided by 255 and permuted HWC -> CHW).
        """
        seg = seg.detach().cpu().numpy()
        color_ = []
        for i in range(seg.shape[0]):
            colori = colorEncode(seg[i].squeeze())
            colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1)
            color_.append(colori)
        color_ = torch.stack(color_)
        return color_

    def to_pil(self, tensor):
        """Convert a [0, 1] image tensor (singleton dims squeezed) to an RGB PIL image."""
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")

    def display(self):
        """Render the grouping result and the interactive texture-synthesis UI.

        Requires ``load_data`` (sets ``self.data`` / ``self.slic``) and a loaded model.
        """
        args = self.args
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward(self.data, self.slic, return_type='grouping')

        # load_data stores the image in [-1, 1]; map it back to [0, 1] for display.
        data = (self.data + 1) / 2.0
        seg = grouping_mask.view(-1, 1, args.crop_size, args.crop_size)
        # Overlay the colour-coded grouping mask on the input image.
        color_vq = self.draw_color_seg(seg)
        color_vq = color_vq * 0.8 + data.cpu() * 0.2

        _markdown_paragraph("Given the image you chose, our model decomposes the image into ten texture segments, each depicts one kind of texture in the image.")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.markdown("")
        with col2:
            st.markdown("Chosen image")
            st.image(self.to_pil(data))
        with col3:
            st.markdown("Grouping mask")
            st.image(self.to_pil(color_vq))
        with col4:
            st.markdown("")

        seg_onehot = label2one_hot_torch(seg, C=NUM_TEX_GROUPS)
        # One masked copy of the image per texture group (indexed by group id).
        parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)

        _markdown_paragraph("We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, choose and click one of the texture segments below.")
        # Encode in memory: fixed shared temp paths would race between sessions
        # and fail when the directory is missing.
        tmp_img_list = [_png_data_uri(self.to_pil(part)) for part in parts]
        tex_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=0,
        )

        # A negative index means nothing has been clicked yet.
        if tex_idx > -1:
            _markdown_paragraph("You can slide the bar below to set the size of the synthesized texture image.")
            tex_size = st.slider('', 0, 1000, 256)
            # Noise is drawn on a (tex_size // 8)^2 grid in model_forward: snap to a
            # multiple of 8 and keep at least one cell (0 would give empty tensors).
            tex_size = max(8, (tex_size // 8) * 8)
            with st.spinner('Running...'):
                with torch.no_grad():
                    tex = self.model_forward(self.data, self.slic, tex_idx=tex_idx,
                                             tex_size=tex_size, return_type='tex')

            col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
            with col1:
                st.markdown("")
            with col2:
                st.markdown("Chosen examplar segment")
                st.image(self.to_pil(parts[tex_idx]))
            with col3:
                st.markdown("Synthesized texture image")
                st.image(self.to_pil(tex))
            with col4:
                st.markdown("")

            _markdown_paragraph("You can choose another image from the examplar images on the top and start again!")

        # NOTE: model_forward also supports return_type='editing' (replace one
        # texture group with another); its UI is currently disabled.

    def _decode(self, rec_tex, batch, noise_h, noise_w, device):
        """Decode a per-pixel texture-code map to an image in [0, 1].

        Shared by the 'tex' and 'editing' modes of ``model_forward``.
        ``noise_h`` x ``noise_w`` is the spatial size of the noise input.
        """
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
        noise = torch.randn(batch, self.model.sine_wave_dim, noise_h, noise_w).to(device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        # Per-channel gate: spatially average ChannelWeight's output, then sigmoid.
        weight = self.model.ChannelWeight(rec_tex)
        weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
        weight = torch.sigmoid(weight)
        dec_input *= weight
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0

    def model_forward(self, rgb_img, slic, epoch = 1000, test_time = False, test = True,
                      tex_idx = None, tex_size = 256, return_type = 'tex',
                      fill_idx = None, remove_idx = None):
        """Run the encoder / superpixel GCN and return one of several outputs.

        Args:
            rgb_img: input image batch (B, 3, H, W), as produced by ``load_data``.
            slic: superpixel assignment from ``load_data`` (one-hot over superpixels).
            epoch: clustering in the GCN is enabled when
                ``args.add_clustering_epoch <= epoch``.
            test_time, test: unused; kept for interface compatibility.
            tex_idx: texture group to synthesize (required for 'tex').
            tex_size: side length of the synthesized texture for 'tex'.
            return_type: 'grouping' -> argmax over ``sp_assign`` dim 1 (on CPU);
                'tex' -> tex_size x tex_size texture from group ``tex_idx`` in [0, 1];
                'editing' -> reconstruction where group ``remove_idx`` is refilled
                with the code of group ``fill_idx``, in [0, 1].
            fill_idx, remove_idx: group indices for 'editing'.

        Raises:
            ValueError: on an unknown ``return_type`` or a missing group index.
        """
        if return_type not in ('grouping', 'tex', 'editing'):
            raise ValueError(f"unknown return_type: {return_type!r}")
        args = self.args
        B, _, imgH, imgW = rgb_img.shape

        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        # Texture code, upsampled to full resolution and pooled per superpixel.
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size=(imgH, imgW), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)
        _, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))

        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)

        softmax = F.softmax(sp_assign * args.temperature, dim=1)
        # One pooled code per texture group.
        tex_seg = poolfeat(conv_feats, softmax, avg=True)

        if return_type == 'tex':
            if tex_idx is None:
                # tex_seg[:, None, :] would silently add an axis instead of selecting a group.
                raise ValueError("tex_idx is required when return_type='tex'")
            sampled_code = tex_seg[:, tex_idx, :]
            rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
            return self._decode(rec_tex, B, tex_size // 8, tex_size // 8, tex_code.device)

        # return_type == 'editing'
        if remove_idx is None or fill_idx is None:
            raise ValueError("fill_idx and remove_idx are required when return_type='editing'")
        seg = label2one_hot_torch(torch.argmax(softmax, dim=1).unsqueeze(1), C=softmax.shape[1])
        rec_tex = upfeat(tex_seg, seg)
        # Index + unsqueeze keeps the channel axis and, unlike a slice, also works
        # for negative indices and raises on out-of-range ones.
        remove_mask = seg[:, remove_idx].unsqueeze(1)
        fill_tex = tex_seg[:, fill_idx, :].view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
        rec_tex = rec_tex * (1 - remove_mask) + fill_tex * remove_mask
        return self._decode(rec_tex, B, imgH // 8, imgW // 8, tex_code.device)

    def load_data(self, data_path):
        """Load, crop and superpixel-segment one image.

        Sets ``self.data`` (1, 3, crop, crop) scaled to [-1, 1] and ``self.slic``
        (the one-hot superpixel map with a leading batch axis).
        """
        device = self.args.device
        crop_size = self.args.crop_size
        # Convert to RGB so grayscale / RGBA inputs still give 3 channels;
        # the context manager closes the file handle.
        with Image.open(data_path) as img:
            rgb_img = img.convert("RGB")
        # Fixed crop_size x crop_size window with its top-left corner at (40, 40).
        rgb_img = TF.crop(rgb_img, 40, 40, crop_size, crop_size)

        # compute superpixel
        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10,
                      start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        # SLIC may return more segments than requested; fold extras into the last id.
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(device)

        rgb_img = TF.to_tensor(rgb_img).unsqueeze(0)
        self.data = rgb_img.to(device) * 2 - 1

    def load_model(self, model_path):
        """Load ``cpk['model']`` from ``model_path`` (on CPU) into ``self.model``."""
        # Wrap in DataParallel while loading -- presumably the checkpoint keys
        # carry the 'module.' prefix -- then unwrap again.
        self.model = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        saved_state_dict = cpk['model']
        self.model.load_state_dict(saved_state_dict)
        self.model = self.model.module
        return

    def test(self):
        """ Test function """
        self.test_step(0)
        # display() takes no arguments; the old display(0, 'train') always raised TypeError.
        self.display()


def main():
    """Build the Streamlit page: example picker, then grouping + synthesis for the pick."""
    st.set_page_config(layout="wide")
    st.markdown(""" """, unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    _markdown_paragraph("In this demo, we show how to scrape textures from natural images for texture synthesis and editing.")
    st.markdown("## Texture synthesis")
    _markdown_paragraph("Here we provide a set of example images, please choose and click one image to start.")

    # Sorted so the index returned by clickable_images maps to the same file on
    # every rerun (glob order is unspecified).
    img_list = sorted(glob("data/images/*.jpg")) + sorted(glob("data/test_images/*.jpg"))
    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
        byte_img_list.append(f"data:image/jpeg;base64,{encoded}")

    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Image #{i}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )

    # Resolve the path only after a real selection: index -1 would silently
    # pick the last image, and an empty list would raise IndexError.
    if img_idx > -1:
        img_path = img_list[img_idx]
        img_name = img_path.split("/")[-1]
        args.pretrained_path = os.path.join("weights", img_name.split(".")[0], "cpk.pth")
        tester = Tester(args)
        tester.define_model()
        tester.load_data(img_path)
        tester.load_model(args.pretrained_path)
        tester.display()


if __name__ == '__main__':
    # Install the pinned torch-geometric (presumably needed by the models that
    # define_model imports) into *this* interpreter, and only when it is missing,
    # so that reruns of the script do not invoke pip every time.
    if importlib.util.find_spec("torch_geometric") is None:
        subprocess.run([sys.executable, "-m", "pip", "install", "torch-geometric==1.7.2"], check=False)
    main()