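"""Streamlit demo: Scraping Textures from Natural Images for Synthesis and Editing.

The app lets a user pick an example image, decomposes it into texture groups,
and synthesizes an arbitrary-sized texture image from a chosen group.
"""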
import os
import base64
import importlib
from glob import glob
from PIL import Image
import torch
import sys
sys.path.append(".")
import numpy as np
from libs.test_base import TesterBase
from libs.utils import colorEncode, label2one_hot_torch
from tqdm import tqdm
from libs.options import BaseOptions
import torch.nn.functional as F
from libs.nnutils import poolfeat, upfeat
import streamlit as st
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from st_clickable_images import clickable_images
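# Gather demo options. When a single image is given via --img_path, the
# experiment name (and thus the cache directory) is scoped to that image.
# The absolute BSR dataset paths below are left over from training and appear
# unused by the demo itself.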
args = BaseOptions().gather_options()
if args.img_path is not None:
    args.exp_name = os.path.join(args.exp_name, args.img_path.split('/')[-1].split('.')[0])
args.batch_size = 1
args.data_path = "/home/xli/DATA/BSR_processed/train"
args.label_path = "/home/xli/DATA/BSR/BSDS500/data/groundTruth"
args.device = torch.device("cpu")
args.nsamples = 500
args.out_dir = os.path.join('cachedir', args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)
args.global_code_ch = args.hidden_dim
args.netG_use_noise = True
args.test_time = (args.test_time == 1)
if not hasattr(args, 'tex_code_dim'):
    args.tex_code_dim = 256
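# Tester wraps model construction, data loading, and the Streamlit UI for a
# single chosen image.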
class Tester(TesterBase):
    def define_model(self):
        """Build the AE model named by args.model_name and set it to eval mode."""
        args = self.args
        module = importlib.import_module('models.week0417.{}'.format(args.model_name))
        self.model = module.AE(args)
        self.model.to(args.device)
        self.model.eval()
        return
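    # Render an integer segmentation map as a color image (one color per segment).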
    def draw_color_seg(self, seg):
        seg = seg.detach().cpu().numpy()
        color_ = []
        for i in range(seg.shape[0]):
            colori = colorEncode(seg[i].squeeze())
            colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1)
            color_.append(colori)
        color_ = torch.stack(color_)
        return color_
    def to_pil(self, tensor):
        return transforms.ToPILImage()(tensor.cpu().squeeze().clamp(0.0, 1.0)).convert("RGB")
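    # Main UI flow: show the grouping for the chosen image, let the user pick a
    # texture segment and a target size, then display the synthesized texture.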
    def display(self):
        with st.spinner('Running...'):
            with torch.no_grad():
                grouping_mask = self.model_forward(self.data, self.slic, return_type = 'grouping')
            data = (self.data + 1) / 2.0
            seg = grouping_mask.view(-1, 1, self.args.crop_size, self.args.crop_size)
            color_vq = self.draw_color_seg(seg)
            color_vq = color_vq * 0.8 + data.cpu() * 0.2
        st.markdown('<p class="big-font">Given the image you chose, our model decomposes the image into ten texture segments, each depicting one kind of texture in the image.</p>', unsafe_allow_html=True)
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.markdown("")
        with col2:
            st.markdown("Chosen image")
            st.image(self.to_pil(data))
        with col3:
            st.markdown("Grouping mask")
            st.image(self.to_pil(color_vq))
        with col4:
            st.markdown("")
        seg_onehot = label2one_hot_torch(seg, C = 10)
        parts = data.cpu() * seg_onehot.squeeze().unsqueeze(1)
        st.markdown('<p class="big-font">We show all texture segments below. To synthesize an arbitrary-sized texture image from a texture segment, choose and click one of the segments.</p>', unsafe_allow_html=True)
        os.makedirs('tmp', exist_ok=True)
        tmp_img_list = []
        for i in range(parts.shape[0]):
            part_img = self.to_pil(parts[i])
            out_path = 'tmp/{}.png'.format(i)
            part_img.save(out_path)
            with open(out_path, "rb") as image:
                encoded = base64.b64encode(image.read()).decode()
            tmp_img_list.append(f"data:image/png;base64,{encoded}")
        tex_idx = clickable_images(
            tmp_img_list,
            titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "150px"},
            key=0
        )
        if tex_idx > -1:
            with st.spinner('Running...'):
                st.markdown('<p class="big-font">You can slide the bar below to set the size of the synthesized texture image.</p>', unsafe_allow_html=True)
                tex_size = st.slider('', 0, 1000, 256)
                tex_size = (tex_size // 8) * 8
                with torch.no_grad():
                    tex = self.model_forward(self.data, self.slic, tex_idx = tex_idx, tex_size = tex_size, return_type = 'tex')
                col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
                with col1:
                    st.markdown("")
                with col2:
                    st.markdown("Chosen exemplar segment")
                    st.image(self.to_pil(parts[tex_idx]))
                with col3:
                    st.markdown("Synthesized texture image")
                    st.image(self.to_pil(tex))
                with col4:
                    st.markdown("")
                st.markdown('<p class="big-font">You can choose another image from the exemplar images at the top and start again!</p>', unsafe_allow_html=True)
                #torch.cuda.empty_cache()
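        # The texture-editing UI below is kept for reference but is disabled in
        # this demo: it would remove one texture group and fill the freed pixels
        # with the texture of another group.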
""" | |
st.markdown("#### Texture Editing") | |
st.markdown("**Choose one texture segment to remove.**") | |
remove_idx = clickable_images( | |
tmp_img_list, | |
titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))], | |
div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, | |
img_style={"margin": "5px", "height": "120px"}, | |
key=1 | |
) | |
st.markdown("**Choose one texture segment to fill in the missing pixels.**") | |
fill_idx = clickable_images( | |
tmp_img_list, | |
titles=[f"Group #{str(i)}" for i in range(len(tmp_img_list))], | |
div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"}, | |
img_style={"margin": "5px", "height": "120px"}, | |
key=2 | |
) | |
rec = self.model_forward(self.data, self.slic, return_type = 'editing', fill_idx = fill_idx, remove_idx = remove_idx) | |
st.image(self.to_pil(rec)) | |
""" | |
    def model_forward(self, rgb_img, slic, epoch = 1000, test_time = False,
                      test = True, tex_idx = None, tex_size = 256,
                      return_type = 'tex', fill_idx = None, remove_idx = None):
        args = self.args
        B, _, imgH, imgW = rgb_img.shape
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.model.enc(rgb_img)
        B, C, H, W = conv_feat.shape
        # Texture code for each superpixel
        tex_code = self.model.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size = (imgH, imgW), mode = 'bilinear', align_corners = False)
        pool_code = poolfeat(code, slic, avg = True)
        prop_code, sp_assign, conv_feats = self.model.gcn(pool_code, slic, (args.add_clustering_epoch <= epoch))
        softmax = F.softmax(sp_assign * args.temperature, dim = 1)
        if return_type == 'grouping':
            return torch.argmax(sp_assign.cpu(), dim = 1)
        tex_seg = poolfeat(conv_feats, softmax, avg = True)
        seg = label2one_hot_torch(torch.argmax(softmax, dim = 1).unsqueeze(1), C = softmax.shape[1])
        if return_type == 'tex':
            sampled_code = tex_seg[:, tex_idx, :]
            rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
            sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
            H = tex_size // 8; W = tex_size // 8
            noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
            dec_input = torch.cat((sine_wave, noise), dim = 1)
            weight = self.model.ChannelWeight(rec_tex)
            weight = F.adaptive_avg_pool2d(weight, output_size = 1).view(weight.shape[0], -1, 1, 1)
            weight = torch.sigmoid(weight)
            dec_input *= weight
            rep_rec = self.model.G(dec_input, rec_tex)
            rep_rec = (rep_rec + 1) / 2.0
            return rep_rec
        elif return_type == 'editing':
            rec_tex = upfeat(tex_seg, seg)
            remove_mask = seg[:, remove_idx:remove_idx+1]
            fill_tex = tex_seg[:, fill_idx, :].view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
            rec_tex = rec_tex * (1 - remove_mask) + fill_tex * remove_mask
            sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
            H = imgH // 8; W = imgW // 8
            noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
            dec_input = torch.cat((sine_wave, noise), dim = 1)
            weight = self.model.ChannelWeight(rec_tex)
            weight = F.adaptive_avg_pool2d(weight, output_size = 1).view(weight.shape[0], -1, 1, 1)
            weight = torch.sigmoid(weight)
            dec_input *= weight
            rep_rec = self.model.G(dec_input, rec_tex)
            rep_rec = (rep_rec + 1) / 2.0
            return rep_rec
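    # Load one image: crop it to crop_size, compute up to 196 SLIC superpixels,
    # and cache the one-hot superpixel map plus the [-1, 1]-normalized tensor.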
    def load_data(self, data_path):
        rgb_img = Image.open(data_path)
        crop_size = self.args.crop_size
        i = 40; j = 40; h = crop_size; w = crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        # compute superpixels
        sp_num = 196
        slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
        slic_i = torch.from_numpy(slic_i)
        slic_i[slic_i >= sp_num] = sp_num - 1
        oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = sp_num).squeeze()
        self.slic = oh.unsqueeze(0).to(self.args.device)
        rgb_img = TF.to_tensor(rgb_img)
        rgb_img = rgb_img.unsqueeze(0)
        self.data = rgb_img.to(self.args.device) * 2 - 1
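    # The checkpoint appears to have been saved from a DataParallel-wrapped
    # model ('module.'-prefixed keys), so wrap before loading, then unwrap.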
    def load_model(self, model_path):
        self.model = torch.nn.DataParallel(self.model)
        cpk = torch.load(model_path, map_location=torch.device('cpu'))
        saved_state_dict = cpk['model']
        self.model.load_state_dict(saved_state_dict)
        self.model = self.model.module
        return
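    # Offline test entry point; unused by the Streamlit flow, which calls
    # display() directly. test_step() is expected to come from TesterBase.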
    def test(self):
        """Run a single test step and render the result."""
        #for iteration in tqdm(range(args.nsamples)):
        self.test_step(0)
        self.display()
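# Page layout: title and instructions, then a clickable grid of example images.
# Clicking an image builds a Tester that loads that image's pretrained weights.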
def main():
    #torch.cuda.empty_cache()
    st.set_page_config(layout="wide")
    st.markdown("""
        <style>
        .big-font {
            font-size:30px !important;
        }
        </style>
    """, unsafe_allow_html=True)
    st.title("Scraping Textures from Natural Images for Synthesis and Editing")
    #st.markdown("**In this demo, we show how to scrape textures from natural images for texture synthesis and editing.**")
    st.markdown('<p class="big-font">In this demo, we show how to scrape textures from natural images for texture synthesis and editing.</p>', unsafe_allow_html=True)
    st.markdown("## Texture synthesis")
    st.markdown('<p class="big-font">Here we provide a set of example images; please choose and click one image to start.</p>', unsafe_allow_html=True)
    img_list = glob("data/images/*.jpg")
    img_list.extend(glob("data/test_images/*.jpg"))
    byte_img_list = []
    for img_path in img_list:
        with open(img_path, "rb") as image:
            encoded = base64.b64encode(image.read()).decode()
        byte_img_list.append(f"data:image/jpeg;base64,{encoded}")
    img_idx = clickable_images(
        byte_img_list,
        titles=[f"Image #{str(i)}" for i in range(len(byte_img_list))],
        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
        img_style={"margin": "5px", "height": "150px"},
    )
    if img_idx > -1:
        img_path = img_list[img_idx]
        img_name = img_path.split("/")[-1]
        args.pretrained_path = "weights/{}/cpk.pth".format(img_name.split(".")[0])
        tester = Tester(args)
        tester.define_model()
        tester.load_data(img_path)
        tester.load_model(args.pretrained_path)
        tester.display()
if __name__ == '__main__':
    os.system("pip install torch-geometric==1.7.2")
    main()
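# A minimal sketch of how to launch the demo locally, assuming this file is
# saved as app.py and per-image checkpoints exist under weights/<image>/cpk.pth:
#   pip install streamlit st-clickable-images torch torchvision scikit-image
#   streamlit run app.py
# The runtime `pip install torch-geometric==1.7.2` above is a workaround for
# hosted environments (e.g. Hugging Face Spaces) where that dependency is not
# preinstalled.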