TextureScraping / app.py
sunshineatnoon's picture
Update app.py
3bfa730
raw
history blame
12.5 kB
import os
import time
import json
import base64
import argparse
import importlib
from glob import glob
from PIL import Image
from imageio import imsave
import torch
import torchvision.utils as vutils
import sys
sys.path.append(".")
import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat
import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images
# Parse the command-line options, then apply the overrides this demo relies on.
args = BaseOptions().gather_options()
if args.img_path is not None:
    # Scope the experiment name by the image file name, cut at its first '.'.
    args.exp_name = os.path.join(
        args.exp_name, args.img_path.rsplit('/', 1)[-1].partition('.')[0]
    )

# Demo runs one image at a time on CPU.
args.batch_size = 1
args.device = torch.device("cpu")
args.nsamples = 500
# Dataset locations from the original training setup. Not read anywhere else in
# this file; presumably consumed by TesterBase — TODO confirm.
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"

# Per-experiment cache directory, created eagerly at import time.
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)

args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
# The option arrives as an integer flag; downstream code wants a bool.
args.test_time = args.test_time == 1
# Fall back to a 256-dim texture code when the options do not define one.
args.tex_code_dim = getattr(args, 'tex_code_dim', 256)
class Tester(TesterBase):
    """Streamlit front-end around a pretrained texture-scraping autoencoder.

    Typical use: ``define_model()``, ``load_data(path)``, ``load_model(ckpt)``,
    then ``display()`` to render the grouping mask, the per-group texture
    segments, and a synthesized texture for the segment the user clicks.
    """

    # Number of texture groups shown in the UI (width of the one-hot grouping
    # mask built in display; the page text also says "ten texture segments").
    NUM_GROUPS = 10

    def define_model(self):
        """Import ``models.week0417.<model_name>`` and build its ``AE`` in eval mode."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()

    def draw_color_seg(self, seg):
        """Colour-code a batch of label maps.

        Args:
            seg: tensor whose first dimension indexes the batch; each entry is
                squeezed and passed to ``colorEncode``.

        Returns:
            Float tensor stacked over the batch, channels-first, scaled to [0, 1].
        """
        labels = seg.detach().cpu().numpy()
        colored = [
            torch.from_numpy(colorEncode(label.squeeze()) / 255.0).float().permute(2, 0, 1)
            for label in labels
        ]
        return torch.stack(colored)

    def to_pil(self, tensor):
        """Convert an image tensor in [0, 1] (clamped) to an RGB PIL image."""
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")

    @staticmethod
    def _to_data_uri(img):
        """Encode a PIL image as an inline PNG ``data:`` URI for clickable_images."""
        import io

        # Encode in memory: no shared tmp/ directory that must exist or that
        # concurrent sessions could overwrite between save and read-back.
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        encoded = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"

    def _show_pair(self, widths, left, right):
        """Render two (caption, image tensor) pairs in the middle of a padded 4-column row."""
        pad_left, col_a, col_b, pad_right = st.columns(widths)
        with pad_left:
            st.markdown("")
        for column, (caption, tensor) in ((col_a, left), (col_b, right)):
            with column:
                st.markdown(caption)
                st.image(self.to_pil(tensor))
        with pad_right:
            st.markdown("")

    def display(self):
        """Render the demo page for the image prepared by ``load_data``.

        Shows the chosen image beside its colour-coded grouping mask, lists each
        texture segment as a clickable thumbnail and, once one is clicked,
        synthesizes a square texture of user-selected size from that segment.
        """
        crop_size = self.args.crop_size
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward(self.data, self.slic, return_type='grouping')
            # self.data is stored in [-1, 1] (see load_data); map back to [0, 1].
            data = (self.data + 1) / 2.0
            seg = grouping_mask.view(-1, 1, crop_size, crop_size)
            # Blend the colour-coded mask over the image so both stay visible.
            color_vq = self.draw_color_seg(seg) * 0.8 + data.cpu() * 0.2

        st.markdown('<p class="big-font">Given the image you chose, our model decomposes the image into ten texture segments, each depicts one kind of texture in the image.</p>', unsafe_allow_html=True)
        self._show_pair(4, ("Chosen image", data), ("Grouping mask", color_vq))

        seg_onehot = label2one_hot_torch(seg, C=self.NUM_GROUPS)
        # One masked copy of the image per group; presumably (groups, 3, H, W).
        parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)
        st.markdown('<p class="big-font">We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, choose and click one of the texture segments below.</p>', unsafe_allow_html=True)
        part_uris = [self._to_data_uri(self.to_pil(part)) for part in parts]
        tex_idx = clickable_images(
            part_uris,
            titles=[f"Group #{i}" for i in range(len(part_uris))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=0,
        )
        if tex_idx > -1:
            st.markdown('<p class="big-font">You can slide the bar below to set the size of the synthesized texture image.</p>', unsafe_allow_html=True)
            # model_forward samples noise on a tex_size // 8 grid, so the size is
            # rounded down to a multiple of 8; the slider starts at 8 because a
            # size of 0 would produce empty tensors.
            tex_size = st.slider('', 8, 1000, 256)
            tex_size = (tex_size // 8) * 8
            with st.spinner('Running...'):
                with torch.no_grad():
                    tex = self.model_forward(self.data, self.slic, tex_idx=tex_idx,
                                             tex_size=tex_size, return_type='tex')
            self._show_pair([1, 1, 4, 1],
                            ("Chosen exemplar segment", parts[tex_idx]),
                            ("Synthesized texture image", tex))
            st.markdown('<p class="big-font">You can choose another image from the exemplar images on the top and start again!</p>', unsafe_allow_html=True)

        # NOTE: an interactive texture-editing UI (pick one segment to remove and
        # one to fill the hole with, then render
        # self.model_forward(..., return_type='editing', fill_idx=..., remove_idx=...))
        # is currently disabled; model_forward still supports that mode.

    def _decode_texture(self, rec_tex, batch_size, out_h, out_w, device):
        """Decode a texture-code map into an image mapped from [-1, 1] to [0, 1].

        Shared by the 'tex' and 'editing' paths of ``model_forward``; the call
        order (sine wave, noise, channel weights, generator) is kept unchanged.
        """
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
        # Noise is sampled at 1/8 of the requested output resolution.
        noise = torch.randn(batch_size, self.model.sine_wave_dim, out_h // 8, out_w // 8).to(device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        # Per-channel gates: pooled over space, squashed with a sigmoid.
        weight = self.model.ChannelWeight(rec_tex)
        weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
        dec_input = dec_input * torch.sigmoid(weight)
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0

    def model_forward(self, rgb_img, slic, epoch = 1000, test_time = False,
                      test = True, tex_idx = None, tex_size = 256,
                      return_type = 'tex', fill_idx = None, remove_idx = None):
        """Run the autoencoder on one image and return what ``return_type`` selects.

        Args:
            rgb_img: image batch (B, 3, H, W), scaled to [-1, 1] by load_data.
            slic: superpixel assignment built by load_data.
            epoch: clustering is enabled in the GCN when
                ``args.add_clustering_epoch <= epoch``.
            test_time, test: unused; kept for interface compatibility.
            tex_idx: group to synthesize from (required for 'tex').
            tex_size: side length of the synthesized square texture ('tex').
            return_type: 'grouping' -> argmax group index from the GCN assignment;
                'tex' -> synthesized texture in [0, 1];
                'editing' -> image with group ``remove_idx`` re-textured by
                group ``fill_idx``, in [0, 1].
            fill_idx, remove_idx: group indices (required for 'editing').

        Raises:
            ValueError: if the indices required by ``return_type`` are missing.
        """
        if return_type == 'tex' and tex_idx is None:
            raise ValueError("return_type='tex' requires tex_idx")
        if return_type == 'editing' and (fill_idx is None or remove_idx is None):
            raise ValueError("return_type='editing' requires fill_idx and remove_idx")

        args = self.args
        B, _, imgH, imgW = rgb_img.shape
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        # Texture code per pixel, then averaged inside each superpixel.
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size=(imgH, imgW), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)
        _, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)

        softmax = F.softmax(sp_assign * args.temperature, dim=1)
        # One code vector per group, pooled with the soft assignment.
        tex_seg = poolfeat(conv_feats, softmax, avg=True)

        if return_type == 'tex':
            sampled_code = tex_seg[:, tex_idx, :]
            rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
            return self._decode_texture(rec_tex, B, tex_size, tex_size, tex_code.device)

        if return_type == 'editing':
            seg = label2one_hot_torch(torch.argmax(softmax, dim=1).unsqueeze(1), C=softmax.shape[1])
            rec_tex = upfeat(tex_seg, seg)
            remove_mask = seg[:, remove_idx:remove_idx + 1]
            fill_tex = tex_seg[:, fill_idx, :].view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
            # Swap the removed group's codes for the fill group's code.
            rec_tex = rec_tex * (1 - remove_mask) + fill_tex * remove_mask
            return self._decode_texture(rec_tex, B, imgH, imgW, tex_code.device)

        # NOTE(review): other return_type values return None, as before.
        return None

    def load_data(self, data_path):
        """Load, crop and superpixel one image into ``self.data`` / ``self.slic``.

        ``self.data`` becomes a (1, 3, crop, crop) tensor scaled to [-1, 1];
        ``self.slic`` is the one-hot encoding of the SLIC labels with a leading
        batch dimension. Both are moved to ``self.args.device``.
        """
        args = self.args
        crop_size = args.crop_size
        sp_num = 196  # number of SLIC superpixels requested
        # Force 3-channel RGB (greyscale / palette / RGBA inputs would otherwise
        # reach SLIC and the encoder with the wrong channel count); the context
        # manager closes the underlying file.
        with Image.open(data_path) as img:
            rgb_img = img.convert('RGB')
        # Fixed crop with its top-left corner at row 40, column 40.
        rgb_img = TF.crop(rgb_img, 40, 40, crop_size, crop_size)
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        # Clamp labels into [0, sp_num) so the one-hot below has sp_num channels.
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(args.device)
        rgb_img = TF.to_tensor(rgb_img).unsqueeze(0)
        self.data = rgb_img.to(args.device) * 2 - 1

    def load_model(self, model_path):
        """Load ``cpk['model']`` from ``model_path`` (on CPU) into ``self.model``.

        The model is temporarily wrapped in DataParallel so the state-dict keys
        line up — presumably the checkpoint was saved from a DataParallel model
        (keys prefixed with ``module.``); TODO confirm. It is unwrapped afterwards.
        """
        self.model = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(cpk['model'])
        self.model = self.model.module

    def test(self):
        """Run one test step, then render the page."""
        self.test_step(0)
        # display() takes no arguments; the old call display(0, 'train') raised TypeError.
        self.display()
def main():
    """Streamlit entry point: pick an example image, then run the demo on it."""
    st.set_page_config(layout="wide")
    st.markdown("""
<style>
.big-font {
font-size:30px !important;
}
</style>
""", unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    st.markdown('<p class="big-font">In this demo, we show how to scrape textures from natural images for texture synthesis and editing.</p>', unsafe_allow_html=True)
    st.markdown("## Texture synthesis")
    st.markdown('<p class="big-font">Here we provide a set of example images, please choose and click one image to start.</p>', unsafe_allow_html=True)

    # Sorted so the index returned by clickable_images maps to the same file on
    # every Streamlit rerun, regardless of filesystem listing order.
    img_list = sorted(glob(os.path.join("data", "images", "*.jpg")))
    img_list.extend(sorted(glob(os.path.join("data", "test_images", "*.jpg"))))

    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
        byte_img_list.append(f"data:image/jpeg;base64,{encoded}")

    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Image #{i}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )
    # -1 means nothing has been clicked yet. Check before indexing: img_list[-1]
    # would silently pick the last image, and an empty list would raise IndexError.
    if img_idx is None or img_idx < 0:
        return

    img_path = img_list[img_idx]
    # Weights live under weights/<file name up to its first '.'>/cpk.pth.
    img_stem = os.path.basename(img_path).split(".")[0]
    args.pretrained_path = os.path.join("weights", img_stem, "cpk.pth")

    tester = Tester(args)
    tester.define_model()
    tester.load_data(img_path)
    tester.load_model(args.pretrained_path)
    tester.display()
if __name__ == '__main__':
    # torch-geometric is installed at launch rather than imported above;
    # presumably the model module imported later via importlib needs it — TODO
    # confirm. Streamlit re-executes this script on every interaction, so pip
    # only runs when the pinned version is missing. It runs through the pip of
    # the current interpreter (sys.executable -m pip) and without a shell.
    import subprocess

    _TG_DIST, _TG_VERSION = "torch-geometric", "1.7.2"
    try:
        from importlib.metadata import version as _dist_version
        _tg_installed = _dist_version(_TG_DIST)
    except ImportError:  # Python < 3.8, or distribution not installed (PackageNotFoundError)
        _tg_installed = None
    if _tg_installed != _TG_VERSION:
        # check=False mirrors os.system: a failed install does not stop the app.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", f"{_TG_DIST}=={_TG_VERSION}"],
            check=False,
        )
    main()