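"""Streamlit demo: "Scraping Textures from Natural Images for Synthesis and Editing".

Lets the user pick an example image, decomposes it into texture segments, and
either synthesizes an arbitrary-sized texture from a chosen segment or edits a
mask by assigning a texture segment to each mask group. Typically launched with
`streamlit run <this file>` (extra model flags, if any, go after `--`).
"""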
import os
import base64
import importlib
import sys
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
import streamlit as st
from PIL import Image
from skimage.segmentation import slic
from st_clickable_images import clickable_images

sys.path.append(".")
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from libs.options import BaseOptions
from libs.nnutils import poolfeat
# Demo configuration: parse CLI options, then pin demo-specific overrides.
args = BaseOptions().gather_options()
if args.img_path is not None:
    args.exp_name = os.path.join(args.exp_name, args.img_path.split('/')[-1].split('.')[0])
args.batch_size = 1
args.data_path = "/home/xli/DATA/BSR_processed/train"  # machine-specific dataset defaults
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"
args.device = torch.device("cpu")
args.nsamples = 500
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
os.makedirs('tmp', exist_ok=True)  # scratch dir for the thumbnails saved below
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = (args.test_time == 1)
if not hasattr(args, 'tex_code_dim'):
    args.tex_code_dim = 256
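# The option attributes consumed below (model_name, crop_size, temperature,
# add_clustering_epoch, hidden_dim, ...) are assumed to be defined by
# BaseOptions; see libs/options.py.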
class Tester(TesterBase):
    def define_model(self):
        """Instantiate the autoencoder and put it in eval mode on the target device."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()

    def draw_color_seg(self, seg):
        """Map an integer segmentation (B, 1, H, W) to a color image tensor (B, 3, H, W)."""
        seg = seg.detach().cpu().numpy()
        color_ = []
        for i in range(seg.shape[0]):
            colori = colorEncode(seg[i].squeeze())
            colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1)
            color_.append(colori)
        return torch.stack(color_)

    def to_pil(self, tensor):
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")
    def display_synthesis(self):
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward_synthesis(self.data, self.slic, return_type='grouping')
            data = (self.data + 1) / 2.0
            seg = grouping_mask.view(-1, 1, self.args.crop_size, self.args.crop_size)
            color_vq = self.draw_color_seg(seg)
            color_vq = color_vq * 0.8 + data.cpu() * 0.2
            st.markdown('<p class="big-font">Given the image you chose, our model decomposes it into ten texture segments, each depicting one kind of texture in the image.</p>', unsafe_allow_html=True)
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.markdown("")
            with col2:
                st.markdown("Chosen image")
                st.image(self.to_pil(data))
            with col3:
                st.markdown("Grouping mask")
                st.image(self.to_pil(color_vq))
            with col4:
                st.markdown("")
            seg_onehot = label2one_hot_torch(seg, C=10)
            parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)
            st.markdown('<p class="big-font">We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, click one of the segments below.</p>', unsafe_allow_html=True)
            tmp_img_list = []
            for i in range(parts.shape[0]):
                part_img = self.to_pil(parts[i])
                out_path = 'tmp/{}.png'.format(i)
                part_img.save(out_path)
                with open(out_path, "rb") as image:
                    encoded = base64.b64encode(image.read()).decode()
                tmp_img_list.append(f"data:image/png;base64,{encoded}")
        tex_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{i}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=0,
        )
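        # clickable_images returns the index of the clicked thumbnail, or -1 if
        # nothing has been clicked yet.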
        if tex_idx > -1:
            with st.spinner('Running...'):
                st.markdown('<p class="big-font">You can slide the bar below to set the size of the synthesized texture image.</p>', unsafe_allow_html=True)
                tex_size = st.slider('', 8, 1000, 256)
                tex_size = max(8, (tex_size // 8) * 8)  # the decoder operates on multiples of 8
                with torch.no_grad():
                    tex = self.model_forward_synthesis(self.data, self.slic, tex_idx=tex_idx, tex_size=tex_size, return_type='tex')
                col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
                with col1:
                    st.markdown("")
                with col2:
                    st.markdown("Chosen exemplar segment")
                    st.image(self.to_pil(parts[tex_idx]))
                with col3:
                    st.markdown("Synthesized texture image")
                    st.image(self.to_pil(tex))
                with col4:
                    st.markdown("")
                st.markdown('<p class="big-font">You can choose another image from the exemplar images at the top and start again!</p>', unsafe_allow_html=True)
    def model_forward_synthesis(self, rgb_img, slic, epoch=1000, test_time=False,
                                test=True, tex_idx=None, tex_size=256,
                                return_type='tex', fill_idx=None, remove_idx=None):
        args = self.args
        B, _, imgH, imgW = rgb_img.shape
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        B, C, H, W = conv_feat.shape
        # Texture code for each superpixel
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size=(imgH, imgW), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)
        # Group superpixels into texture segments with the GCN head
        prop_code, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))
        softmax = F.softmax(sp_assign * args.temperature, dim=1)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)
        # One texture code per segment
        tex_seg = poolfeat(conv_feats, softmax, avg=True)
        seg = label2one_hot_torch(torch.argmax(softmax, dim=1).unsqueeze(1), C=softmax.shape[1])
        # Broadcast the chosen segment's code over the target canvas and decode
        sampled_code = tex_seg[:, tex_idx, :]
        rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
        H = W = tex_size // 8
        noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        weight = self.model.ChannelWeight(rec_tex)
        weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
        weight = torch.sigmoid(weight)
        dec_input *= weight
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0
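    # A minimal usage sketch (assuming `tester` was built as in main() below):
    #   with torch.no_grad():
    #       labels = tester.model_forward_synthesis(tester.data, tester.slic, return_type='grouping')
    #       tex = tester.model_forward_synthesis(tester.data, tester.slic,
    #                                            tex_idx=3, tex_size=256, return_type='tex')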
    def display_editing(self):
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward_editing(self.data, self.slic, return_type='grouping')
            data = (self.data + 1) / 2.0
            seg = grouping_mask.view(-1, 1, self.args.crop_size, self.args.crop_size)
            color_vq = self.draw_color_seg(seg)
            color_vq = color_vq * 0.8 + data.cpu() * 0.2
            st.markdown('<p class="big-font">Given the image you chose, our model decomposes it into ten texture segments, each depicting one kind of texture in the image.</p>', unsafe_allow_html=True)
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.markdown("")
            with col2:
                st.markdown("Chosen image")
                st.image(self.to_pil(data))
            with col3:
                st.markdown("Grouping mask")
                st.image(self.to_pil(color_vq))
            with col4:
                st.markdown("")
            seg_onehot = label2one_hot_torch(seg, C=10)
            parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)
            st.markdown('<p class="big-font">We show all texture segments below.</p>', unsafe_allow_html=True)
            tmp_img_list = []
            for i in range(parts.shape[0]):
                part_img = self.to_pil(parts[i])
                out_path = 'tmp/{}.png'.format(i)
                part_img.save(out_path)
                with open(out_path, "rb") as image:
                    encoded = base64.b64encode(image.read()).decode()
                tmp_img_list.append(f"data:image/png;base64,{encoded}")
        tex_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{i}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=2,
        )
st.markdown('<p class="big-font">Choose one mask for texture editing.</p>', unsafe_allow_html=True) | |
mask_list = glob(os.path.join("data/masks/*.png")) | |
byte_mask_list = [] | |
for img_path in mask_list: | |
seg = Image.open(img_path).convert("L") | |
seg = np.asarray(seg) | |
seg = torch.from_numpy(seg).view(1, 1, seg.shape[0], seg.shape[1]) | |
color_vq = self.draw_color_seg(seg) | |
vutils.save_image(color_vq, 'tmp/tmp.png') | |
with open('tmp/tmp.png', "rb") as image: | |
encoded = base64.b64encode(image.read()).decode() | |
byte_mask_list.append(f"data:image/jpeg;base64,{encoded}") | |
img_idx = clickable_images( | |
byte_mask_list, | |
titles=[f"Group #{str(i)}" for i in range(len(byte_mask_list))], | |
div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, | |
img_style={"margin": "5px", "height": "150px"}, | |
) | |
mask_path = mask_list[img_idx] | |
st.markdown('<p class="big-font">Choose the texture segment for each group in the given mask below.</p>', unsafe_allow_html=True) | |
given_mask = Image.open(mask_path).convert("L") | |
given_mask = np.asarray(given_mask) | |
given_mask = torch.from_numpy(given_mask) | |
H, W = given_mask.shape[0], given_mask.shape[1] | |
given_mask = label2one_hot_torch(given_mask.view(1, 1, H, W), C = (given_mask.max()+1)) | |
mask_img_list = [] | |
for i in range(given_mask.shape[1]): | |
part_img = self.to_pil(given_mask[0, i]) | |
out_path = 'tmp/{}.png'.format(i) | |
part_img.save(out_path) | |
with open(out_path, "rb") as image: | |
encoded = base64.b64encode(image.read()).decode() | |
mask_img_list.append(f"data:image/jpeg;base64,{encoded}") | |
part_idx = clickable_images( | |
mask_img_list, | |
div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, | |
img_style={"margin": "5px", "height": "150px"}, | |
key=1 | |
) | |
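        # part_idx is currently display-only: the selectboxes below, one per mask
        # group, are what actually pick the texture segment for each group.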
        cols = st.columns(len(mask_img_list))
        options = []
        for i, col in enumerate(cols):
            with col:
                option = st.selectbox(
                    "",
                    [str(ii) for ii in range(10)],
                    key='seg_select_{}'.format(i))
                options.append(int(option))
        if len(options) > 0:
            with st.spinner('Running...'):
                st.markdown('<p class="big-font">The edited image is shown below.</p>', unsafe_allow_html=True)
                with torch.no_grad():
                    edited = self.model_forward_editing(self.data, self.slic, options=options, given_mask=given_mask, return_type='edited')
                col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
                with col1:
                    st.markdown("")
                with col2:
                    st.markdown("Input image")
                    img = F.interpolate(self.data, size=edited.shape[-2:], mode='bilinear', align_corners=False)
                    st.image(self.to_pil((img + 1) / 2.0))
                with col3:
                    st.markdown("Edited image")
                    st.image(self.to_pil(edited))
                with col4:
                    st.markdown("")
                st.markdown('<p class="big-font">You can choose another image from the exemplar images at the top and start again!</p>', unsafe_allow_html=True)
    def model_forward_editing(self, rgb_img, slic, epoch=1000, test_time=False,
                              test=True, tex_idx=None, tex_size=256,
                              return_type='edited', fill_idx=None, remove_idx=None,
                              options=None, given_mask=None):
        args = self.args
        B, _, imgH, imgW = rgb_img.shape
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        B, C, H, W = conv_feat.shape
        # Texture code for each superpixel
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size=(imgH, imgW), mode='bilinear', align_corners=False)
        pool_code = poolfeat(code, slic, avg=True)
        # Group superpixels into texture segments with the GCN head
        prop_code, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))
        softmax = F.softmax(sp_assign * args.temperature, dim=1)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim=1)
        # One texture code per segment
        tex_seg = poolfeat(conv_feats, softmax, avg=True)
        seg = label2one_hot_torch(torch.argmax(softmax, dim=1).unsqueeze(1), C=softmax.shape[1])
        # Paint the user-chosen segment code into each group of the given mask
        given_mask = F.interpolate(given_mask, size=(512, 512), mode='bilinear', align_corners=False)
        rec_tex = torch.zeros((1, tex_seg.shape[-1], 512, 512))
        for i in range(given_mask.shape[1]):
            label = options[i]
            code = tex_seg[0, label, :].view(1, -1, 1, 1).repeat(1, 1, 512, 512)
            rec_tex += code * given_mask[:, i:i + 1]
        # Decode the spatially varying texture-code map at a fixed 512x512 size
        tex_size = 512
        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
        H = W = tex_size // 8
        noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
        dec_input = torch.cat((sine_wave, noise), dim=1)
        weight = self.model.ChannelWeight(rec_tex)
        weight = F.adaptive_avg_pool2d(weight, output_size=1).view(weight.shape[0], -1, 1, 1)
        weight = torch.sigmoid(weight)
        dec_input *= weight
        rep_rec = self.model.G(dec_input, rec_tex)
        return (rep_rec + 1) / 2.0
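    # A minimal usage sketch (assuming a loaded `tester` and a one-hot `given_mask`
    # with one channel per mask group; options[i] picks the texture segment for group i;
    # the option values here are illustrative):
    #   with torch.no_grad():
    #       edited = tester.model_forward_editing(tester.data, tester.slic,
    #                                             options=[0, 3, 7], given_mask=given_mask,
    #                                             return_type='edited')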
    def load_data(self, data_path):
        rgb_img = Image.open(data_path)
        crop_size = self.args.crop_size
        # Fixed top-left crop at (40, 40) used by the demo
        rgb_img = TF.crop(rgb_img, 40, 40, crop_size, crop_size)
        # Compute superpixels and convert them to a one-hot assignment map
        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(self.args.device)
        rgb_img = TF.to_tensor(rgb_img)
        rgb_img = rgb_img.unsqueeze(0)
        self.data = rgb_img.to(self.args.device) * 2 - 1  # scale to [-1, 1]
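    # Note: `self.slic` holds the superpixels as a one-hot tensor of shape
    # (1, sp_num, crop_size, crop_size), so that poolfeat(feat, self.slic, avg=True)
    # can average features within each superpixel (our reading of libs.nnutils).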
    def load_model(self, model_path):
        # Checkpoints were saved from a DataParallel model, so the state-dict keys
        # carry a 'module.' prefix; wrap, load, then unwrap to strip it.
        self.model = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(cpk['model'])
        self.model = self.model.module

""" | |
def test(self): | |
#for iteration in tqdm(range(args.nsamples)): | |
self.test_step(0) | |
self.display(0, 'train') | |
""" | |
def main():
    st.set_page_config(layout="wide")
    st.markdown("""
    <style>
    .big-font {
        font-size:30px !important;
    }
    </style>
    """, unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    st.markdown('<p class="big-font">In this demo, we show how to scrape textures from natural images for texture synthesis and editing.</p>', unsafe_allow_html=True)
    st.markdown("## Texture synthesis")
    st.markdown('<p class="big-font">Here we provide a set of example images; please click one to start.</p>', unsafe_allow_html=True)
    img_list = glob("data/images/*.jpg") + glob("data/test_images/*.jpg")
    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
        byte_img_list.append(f"data:image/jpeg;base64,{encoded}")
    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Group #{i}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )
    if img_idx > -1:
        # Each example image ships with its own pre-trained checkpoint.
        img_path = img_list[img_idx]
        img_name = img_path.split("/")[-1]
        args.pretrained_path = "weights/{}/cpk.pth".format(img_name.split(".")[0])
        tester = Tester(args)
        tester.define_model()
        tester.load_data(img_path)
        tester.load_model(args.pretrained_path)
        app_idx = st.selectbox('Please select between texture synthesis and editing',
                               ["Texture Synthesis", "Texture Editing"])
        if app_idx == 'Texture Editing':
            st.header("Texture Editing")
            tester.display_editing()
        else:
            st.header("Texture Synthesis")
            tester.display_synthesis()
if __name__ == '__main__':
    # Install-at-runtime workaround for the hosted demo environment.
    os.system("pip install torch-geometric==1.7.2")
    main()