TextureScraping / app.py
sunshineatnoon
new_model
827b81f
import os
import time
import json
import base64
import argparse
import importlib
from glob import glob
from PIL import Image
from imageio import imsave
import torch
import torchvision.utils as vutils
import sys
sys.path.append(".")
import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat
import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images
# Command-line options, followed by the fixed overrides this CPU-only demo relies on.
args = BaseOptions().gather_options()
if args.img_path is not None:
    # Nest the experiment under the image's file name, cut at its first '.'.
    args.exp_name = os.path.join(args.exp_name, args.img_path.rsplit('/', 1)[-1].partition('.')[0])

# Single-image inference on the CPU.
args.batch_size = 1
args.device = torch.device("cpu")
args.nsamples = 500
# Dataset locations: hard-coded absolute paths from the original author's machine.
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"

# Per-experiment cache directory; created eagerly when this module is executed.
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)

# Model-construction settings derived from / added to the parsed options.
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = args.test_time == 1
args.tex_code_dim = getattr(args, 'tex_code_dim', 256)
class Tester(TesterBase):
    """Streamlit front-end around a pretrained texture-scraping autoencoder.

    ``self.args`` is read throughout (presumably set by ``TesterBase.__init__`` —
    TODO confirm). ``load_data`` sets ``self.data`` (the cropped image scaled to
    [-1, 1]) and ``self.slic`` (its superpixel map), which the ``display_*``
    pages and ``model_forward_*`` passes consume.
    """

    # Width of the one-hot grouping mask and number of texture groups offered
    # per mask region in the UI (presumably the GCN's cluster count — TODO confirm).
    NUM_GROUPS = 10

    # Shared layout for every clickable_images gallery on the page.
    _GALLERY_DIV_STYLE = {"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}
    _GALLERY_IMG_STYLE = {"margin": "5px", "height": "150px"}

    def define_model(self):
        """Build ``models.week0417.<model_name>.AE`` on ``args.device`` in eval mode."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()

    def draw_color_seg(self, seg):
        """Colour-code a batch of label maps.

        Args:
            seg: label tensor indexed by batch along dim 0; each item is squeezed
                before ``colorEncode`` (assumes (B, 1, H, W) — TODO confirm).

        Returns:
            Float tensor of shape (B, C, H, W): ``colorEncode``'s HWC output,
            divided by 255 and moved to channel-first.
        """
        seg = seg.detach().cpu().numpy()
        colored = [
            torch.from_numpy(colorEncode(item.squeeze()) / 255.0).float().permute(2, 0, 1)
            for item in seg
        ]
        return torch.stack(colored)

    def to_pil(self, tensor):
        """Convert a tensor to an RGB PIL image after squeezing singleton dims and clamping to [0, 1]."""
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")

    @staticmethod
    def _png_data_uri(pil_img):
        """Encode a PIL image as an in-memory PNG ``data:`` URI (no temporary files)."""
        import io

        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()

    @staticmethod
    def _load_label_map(path):
        """Read a label image as an 8-bit grayscale (1, 1, H, W) uint8 tensor."""
        with Image.open(path) as img:
            labels = np.array(img.convert("L"))
        return torch.from_numpy(labels).view(1, 1, labels.shape[0], labels.shape[1])

    def _gallery(self, uris, titles=None, key=None):
        """Render ``uris`` as a clickable gallery and return the clicked index.

        A negative return value means nothing has been clicked yet.
        """
        extra = {} if titles is None else {"titles": titles}
        return clickable_images(
            uris,
            div_style=dict(self._GALLERY_DIV_STYLE),
            img_style=dict(self._GALLERY_IMG_STYLE),
            key=key,
            **extra,
        )

    def _render_grouping(self, grouping_mask):
        """Show the input image next to its colour-coded grouping mask.

        Args:
            grouping_mask: per-pixel group indices from a ``return_type='grouping'``
                forward pass; reshaped to (-1, 1, crop_size, crop_size).

        Returns:
            CPU tensor (NUM_GROUPS, 3, H, W): the image in [0, 1] masked by each group.
        """
        crop_size = self.args.crop_size
        data = (self.data + 1) / 2.0
        seg = grouping_mask.view(-1, 1, crop_size, crop_size)
        overlay = self.draw_color_seg(seg) * 0.8 + data.cpu() * 0.2
        st.markdown('<p class="big-font">Given the image you chose, our model decomposes the image into ten texture segments, each depicts one kind of texture in the image.</p>', unsafe_allow_html=True)
        _, col_img, col_mask, _ = st.columns(4)
        with col_img:
            st.markdown("Chosen image")
            st.image(self.to_pil(data))
        with col_mask:
            st.markdown("Grouping mask")
            st.image(self.to_pil(overlay))
        seg_onehot = label2one_hot_torch(seg, C=self.NUM_GROUPS)
        return data.cpu() * seg_onehot.squeeze().unsqueeze(1)

    def display_synthesis(self):
        """Streamlit page: show the grouping, then synthesize a texture from a clicked group."""
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward_synthesis(self.data, self.slic, return_type='grouping')
        parts = self._render_grouping(grouping_mask)
        st.markdown('<p class="big-font">We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, choose and click one of the texture segments below.</p>', unsafe_allow_html=True)
        tex_idx = self._gallery(
            [self._png_data_uri(self.to_pil(part)) for part in parts],
            titles=[f"Group #{i}" for i in range(len(parts))],
            key=0,
        )
        if tex_idx is None or tex_idx < 0:
            return

        st.markdown('<p class="big-font">You can slide the bar below to set the size of the synthesized texture image.</p>', unsafe_allow_html=True)
        tex_size = st.slider('', 0, 1000, 256)
        # The decoder noise lives on a (size // 8) grid: snap to a multiple of 8
        # and never drop below one cell, or the noise tensor would be empty.
        tex_size = max(8, (tex_size // 8) * 8)
        with st.spinner('Running...'):
            with torch.no_grad():
                tex = self.model_forward_synthesis(self.data, self.slic, tex_idx=tex_idx,
                                                   tex_size=tex_size, return_type='tex')
        _, col_src, col_tex, _ = st.columns([1, 1, 4, 1])
        with col_src:
            st.markdown("Chosen examplar segment")
            st.image(self.to_pil(parts[tex_idx]))
        with col_tex:
            st.markdown("Synthesized texture image")
            st.image(self.to_pil(tex))
        st.markdown('<p class="big-font">You can choose another image from the examplar images on the top and start again!</p>', unsafe_allow_html=True)

    def _encode_and_group(self, rgb_img, slic, epoch):
        """Run the encoder and superpixel GCN shared by both forward passes.

        Args:
            rgb_img: (B, 3, H, W) image batch.
            slic: superpixel map passed to ``poolfeat`` / ``self.model.gcn``.
            epoch: compared against ``args.add_clustering_epoch`` to enable clustering.

        Returns:
            ``(tex_code, sp_assign, conv_feats)``: the ``ToTexCode`` map, the GCN's
            group-assignment scores and the GCN features pooled per group later.
        """
        _, _, img_h, img_w = rgb_img.shape
        # Encoder: img (B, 3, H, W) -> feature (B, C, H//8, W//8)
        conv_feat, _ = self.model.enc(rgb_img)
        tex_code = self.model.ToTexCode(conv_feat)
        # Upsample codes to image resolution, then pool them per superpixel.
        code = F.interpolate(tex_code, size=(img_h, img_w), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)
        _, sp_assign, conv_feats = self.model.gcn(pool_code, slic, self.args.add_clustering_epoch <= epoch)
        return tex_code, sp_assign, conv_feats

    def _group_codes(self, sp_assign, conv_feats):
        """Pool GCN features into one code per group using a temperature-scaled softmax over dim 1."""
        softmax = F.softmax(sp_assign * self.args.temperature, dim=1)
        return poolfeat(conv_feats, softmax, avg=True)

    def _decode(self, rec_tex, out_h, out_w, batch_size, device):
        """Decode a dense texture-code map with sine-wave + noise inputs.

        The generator output is mapped with ``(x + 1) / 2`` (presumably from
        [-1, 1] to [0, 1]).
        """
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
        noise = torch.randn(batch_size, self.model.sine_wave_dim, out_h // 8, out_w // 8).to(device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0

    def model_forward_synthesis(self, rgb_img, slic, epoch=1000, test_time=False,
                                test=True, tex_idx=None, tex_size=256,
                                return_type='tex', fill_idx=None, remove_idx=None):
        """Group the image's textures or synthesize a square texture from one group.

        Args:
            rgb_img: (B, 3, H, W) image batch (``self.data`` in this app).
            slic: superpixel map (``self.slic``).
            epoch: forwarded to the GCN clustering switch.
            tex_idx: group whose code is tiled over the canvas; required unless
                ``return_type == 'grouping'``.
            tex_size: side length of the synthesized texture.
            return_type: ``'grouping'`` returns per-pixel argmax group indices;
                anything else returns the synthesized texture.
            test_time, test, fill_idx, remove_idx: accepted for interface
                compatibility; unused.

        Raises:
            ValueError: if a texture is requested without ``tex_idx``.
        """
        batch_size = rgb_img.shape[0]
        tex_code, sp_assign, conv_feats = self._encode_and_group(rgb_img, slic, epoch)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)
        if tex_idx is None:
            raise ValueError("tex_idx is required unless return_type == 'grouping'")
        tex_seg = self._group_codes(sp_assign, conv_feats)
        # Tile the chosen group's code over the whole output canvas.
        sampled_code = tex_seg[:, tex_idx, :]
        rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
        return self._decode(rec_tex, tex_size, tex_size, batch_size, tex_code.device)

    def display_editing(self):
        """Streamlit page: pick a layout mask and a texture group per mask region, then render the edit."""
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward_editing(self.data, self.slic, return_type='grouping')
        parts = self._render_grouping(grouping_mask)
        st.markdown('<p class="big-font">We show all texture segments below.</p>', unsafe_allow_html=True)
        # Display-only gallery: its clicked index is not used on this page.
        self._gallery(
            [self._png_data_uri(self.to_pil(part)) for part in parts],
            titles=[f"Group #{i}" for i in range(len(parts))],
            key=2,
        )

        st.markdown('<p class="big-font">Choose one mask for texture editing.</p>', unsafe_allow_html=True)
        # Sorted so the gallery index maps to the same file on every rerun.
        mask_list = sorted(glob("data/masks/*.png"))
        if not mask_list:
            st.warning("No masks found in data/masks.")
            return
        mask_uris = [
            self._png_data_uri(self.to_pil(self.draw_color_seg(self._load_label_map(path))))
            for path in mask_list
        ]
        mask_idx = self._gallery(mask_uris, titles=[f"Group #{i}" for i in range(len(mask_uris))])
        # Before any click the gallery returns -1, which selects the last mask as a default.
        mask_path = mask_list[mask_idx]

        st.markdown('<p class="big-font">Choose the texture segment for each group in the given mask below.</p>', unsafe_allow_html=True)
        labels = self._load_label_map(mask_path)
        # int() first: on the uint8 tensor, max() + 1 would wrap to 0 for label 255.
        num_regions = int(labels.max()) + 1
        given_mask = label2one_hot_torch(labels, C=num_regions)
        given_mask = F.interpolate(given_mask, size=(512, 512), mode='bilinear', align_corners=False)
        # Display-only gallery of the individual mask regions.
        self._gallery(
            [self._png_data_uri(self.to_pil(given_mask[0, i])) for i in range(given_mask.shape[1])],
            key=1,
        )

        options = []
        for i, col in enumerate(st.columns(given_mask.shape[1])):
            with col:
                # String keys keep these distinct from the integer keys of the galleries above.
                options.append(st.selectbox("", list(range(self.NUM_GROUPS)), key=f"mask_group_{i}"))

        st.markdown('<p class="big-font">Edited image is shown below.</p>', unsafe_allow_html=True)
        with st.spinner('Running...'):
            with torch.no_grad():
                edited = self.model_forward_editing(self.data, self.slic, options=options,
                                                    given_mask=given_mask, return_type='edited')
        col_in, col_mask, col_out = st.columns([1, 1, 1])
        with col_in:
            st.markdown("Input image")
            img = F.interpolate(self.data, size=edited.shape[-2:], mode='bilinear', align_corners=False)
            st.image(self.to_pil((img + 1) / 2.0))
        with col_mask:
            st.markdown("Given mask")
            mask_color = F.interpolate(self.draw_color_seg(labels), size=(512, 512),
                                       mode='bilinear', align_corners=False)
            st.image(self.to_pil(mask_color))
        with col_out:
            st.markdown("Synthesized texture image")
            st.image(self.to_pil(edited))
        st.markdown('<p class="big-font">You can choose another image from the examplar images on the top and start again!</p>', unsafe_allow_html=True)

    def model_forward_editing(self, rgb_img, slic, epoch=1000, test_time=False,
                              test=True, tex_idx=None, tex_size=256,
                              return_type='edited', fill_idx=None, remove_idx=None,
                              options=None, given_mask=None):
        """Group the image's textures or repaint a layout mask with chosen group textures.

        Args:
            rgb_img: (B, 3, H, W) image batch (``self.data`` in this app).
            slic: superpixel map (``self.slic``).
            epoch: forwarded to the GCN clustering switch.
            return_type: ``'grouping'`` returns per-pixel argmax group indices;
                anything else returns the edited image.
            options: group index for each mask region; at least ``given_mask.shape[1]`` entries.
            given_mask: (1, R, H, W) region weights. Possibly soft, e.g. after
                the bilinear resize in ``display_editing``. The output has the
                same H x W.
            tex_idx, tex_size, test_time, test, fill_idx, remove_idx: accepted
                for interface compatibility; unused.

        Raises:
            ValueError: if ``options``/``given_mask`` are missing or too few options are given.
        """
        batch_size = rgb_img.shape[0]
        tex_code, sp_assign, conv_feats = self._encode_and_group(rgb_img, slic, epoch)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)
        if options is None or given_mask is None:
            raise ValueError("options and given_mask are required unless return_type == 'grouping'")
        num_regions = given_mask.shape[1]
        if len(options) < num_regions:
            raise ValueError(f"expected {num_regions} options, got {len(options)}")

        tex_seg = self._group_codes(sp_assign, conv_feats)
        given_mask = given_mask.to(device=tex_seg.device, dtype=tex_seg.dtype)
        out_h, out_w = given_mask.shape[-2:]
        # Paint each region with the code of the group chosen for it; soft mask
        # boundaries blend neighbouring codes.
        rec_tex = torch.zeros((1, tex_seg.shape[-1], out_h, out_w), device=tex_seg.device, dtype=tex_seg.dtype)
        for region, group in enumerate(options[:num_regions]):
            code = tex_seg[0, group, :].view(1, -1, 1, 1)
            rec_tex = rec_tex + code * given_mask[:, region:region + 1]
        return self._decode(rec_tex, out_h, out_w, batch_size, tex_code.device)

    def load_data(self, data_path):
        """Load, resize and crop an image, and compute its superpixel one-hot map.

        Sets ``self.data`` to the (1, 3, crop, crop) image scaled to [-1, 1] and
        ``self.slic`` to the one-hot SLIC map with a leading batch dim. Both are
        placed on ``self.args.device``.
        """
        args = self.args
        crop_size = args.crop_size
        # Force 3 channels: grayscale / RGBA / CMYK files would otherwise reach
        # SLIC and the encoder with the wrong channel count.
        with Image.open(data_path) as img:
            rgb_img = img.convert("RGB")
        rgb_img = transforms.Resize(size=320)(rgb_img)
        rgb_img = TF.crop(rgb_img, 40, 40, crop_size, crop_size)

        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        # SLIC's segment count is approximate; fold any extra labels into the
        # last one so the one-hot map has exactly sp_num channels.
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(args.device)
        self.data = TF.to_tensor(rgb_img).unsqueeze(0).to(args.device) * 2 - 1

    def load_model(self, model_path):
        """Load the ``'model'`` state dict from a checkpoint into ``self.model``.

        Loading goes through a temporary ``DataParallel`` wrapper (presumably
        because the saved keys carry its ``module.`` prefix); the bare module is kept.
        """
        wrapped = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        wrapped.load_state_dict(cpk['model'])
        self.model = wrapped.module
def main():
    """Streamlit entry point: pick an example image, load its weights, show a demo page."""
    st.set_page_config(layout="wide")
    st.markdown("""
<style>
.big-font {
font-size:30px !important;
}
</style>
""", unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    st.markdown('<p class="big-font">In this demo, we show how to scrape textures from natural images for texture synthesis and editing.</p>', unsafe_allow_html=True)
    st.markdown("## Texture synthesis")
    st.markdown('<p class="big-font">Here we provide a set of example images, please choose and click one image to start.</p>', unsafe_allow_html=True)

    # Sorted so the index returned by the gallery maps to the same file on
    # every Streamlit rerun (glob order is not guaranteed).
    img_list = sorted(glob("data/images/*.jpg")) + sorted(glob("data/test_images/*.jpg"))
    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
        byte_img_list.append(f"data:image/jpeg;base64,{encoded}")
    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Group #{i}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )
    # Nothing clicked yet: stop before indexing img_list, which would otherwise
    # pick the last image or raise IndexError when no images exist.
    if img_idx is None or img_idx < 0:
        return

    img_path = img_list[img_idx]
    img_stem = os.path.basename(img_path).split(".")[0]
    args.pretrained_path = os.path.join("weights", img_stem, "cpk.pth")
    tester = Tester(args)
    tester.define_model()
    tester.load_data(img_path)
    tester.load_model(args.pretrained_path)
    app_idx = st.selectbox('Please select between texture synthesis or editing',
                           ["Texture Synthesis", "Texture Editing"])
    if app_idx == 'Texture Editing':
        st.header("Texture Editing")
        tester.display_editing()
    else:
        st.header("Texture Synthesis")
        tester.display_synthesis()
if __name__ == '__main__':
    # Streamlit re-executes this script as __main__ on every widget interaction,
    # so only invoke pip when the pinned torch-geometric release is missing, and
    # install into the running interpreter rather than whatever `pip` is on PATH.
    import subprocess

    try:
        from importlib import metadata  # Python 3.8+
        _pyg_pinned = metadata.version("torch_geometric") == "1.7.2"
    except ImportError:  # also covers metadata.PackageNotFoundError
        _pyg_pinned = False
    if not _pyg_pinned:
        subprocess.run([sys.executable, "-m", "pip", "install", "torch-geometric==1.7.2"], check=False)
    main()