# Spaces:
# Build error
# Build error
import torch | |
import argparse | |
import numpy as np | |
from helper import * | |
from config.GlobalVariables import * | |
from SynthesisNetwork import SynthesisNetwork | |
from DataLoader import DataLoader | |
import convenience | |
import gradio as gr | |
#@title Demo
device = 'cpu'
num_samples = 10
net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
# Always load the pretrained weights. The previous guard
# `if not torch.cuda.is_available():` skipped loading on CUDA hosts, leaving the
# network untrained there; `device` is pinned to CPU and `map_location` remaps
# the checkpoint onto it, so loading is valid on every host.
net.load_state_dict(
    torch.load('./model/250000.pt', map_location=torch.device(device))["model_state_dict"]
)
dl = DataLoader(num_writer=1, num_samples=num_samples, divider=5.0, datadir='./data/writers')
# Writer ids offered by the two radios of the "Blend Writers" tab.
writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]
# Space-separated characters offered by the "Blend Characters" dropdowns.
# The backslash entry is written `\\`: the former `\ ` was an invalid escape
# sequence (SyntaxWarning since Python 3.12); the string value is unchanged.
avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \\ / [ ] ( ) # $ % &"
avail_char_list = avail_char.split(" ")
# Data for the two writers currently being blended. It starts as writers 120
# and 80 to match the default radio selections. Callbacks mutate this list in
# place (clear/append), so it must remain the same list object.
all_loaded_data = [
    dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
    for writer_id in [120, 80]
]
# Writer 80's data; used by the "Blend Characters" and "Add Randomness" tabs.
default_loaded_data = all_loaded_data[-1]
# Cached state for the "Add Randomness" (MDN) tab, rebuilt by update_mdn_word.
mdn_words = []
mdn_mean_Ws = []
all_word_mdn_Ws = []
all_word_mdn_Cs = []
# data for writer interpolation (rebuilt by update_target_word)
writer_words = []
writer_mean_Ws = []
all_word_writer_Ws = []
all_word_writer_Cs = []
# Current writer-blend position: blend weights are [1 - weight, weight].
# Updated by update_writer_slider.
weight = 0.7
def update_target_word(target_word):
    """Rebuild the "Blend Writers" caches for new text and redraw the blend.

    Recomputes each selected writer's mean global style vector from
    ``all_loaded_data``. It then computes the per-word DSD data
    (``all_word_writer_Ws`` / ``all_word_writer_Cs``) for every word of
    ``target_word``. All module-level lists are updated in place.

    Args:
        target_word: text from the "Target Word" textbox; split on whitespace.

    Returns:
        The image from ``update_writer_slider`` at the current ``weight``.
    """
    # split() rather than split(" "): repeated, leading or trailing spaces would
    # otherwise yield empty words that get passed to get_DSD.
    writer_words[:] = target_word.split()
    writer_mean_Ws[:] = [
        convenience.get_mean_global_W(net, loaded_data, device)
        for loaded_data in all_loaded_data
    ]
    all_word_writer_Ws.clear()
    all_word_writer_Cs.clear()
    for word in writer_words:
        word_Ws, word_Cs = convenience.get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
        all_word_writer_Ws.append(word_Ws)
        all_word_writer_Cs.append(word_Cs)
    return update_writer_slider(weight)
# for writer interpolation | |
def update_writer_slider(val):
    """Redraw the cached "Blend Writers" words at slider position ``val``.

    ``val`` is also stored in the module-level ``weight``, so later redraws
    (e.g. after a new target word) reuse it. The blend weights passed to
    ``draw_words`` are ``[1 - val, val]`` (presumably one per writer, in
    ``all_loaded_data`` order).

    Returns:
        The rendered words converted to RGB.
    """
    global weight
    weight = val
    # sample_mdn sets clamp_mdn from the "Maximum Randomness" slider; reset it
    # so this tab is always drawn with clamp_mdn = 0.
    net.clamp_mdn = 0
    writer_mix = [1 - weight, weight]
    canvas = convenience.draw_words(
        writer_words, all_word_writer_Ws, all_word_writer_Cs, writer_mix, net
    )
    return canvas.convert("RGB")
def update_chosen_writers(writer1, writer2):
    """Switch the two blended writers and redraw the current words.

    Args:
        writer1, writer2: radio labels of the form ``"Style <id>"``.

    Returns:
        A tuple ``(slider update with the new label, redrawn RGB image)``.
    """
    net.clamp_mdn = 0
    id1, id2 = (int(choice.split(" ")[1]) for choice in (writer1, writer2))
    # Load both writers first, then swap them in, so a failed load does not
    # leave all_loaded_data half-filled. The slice assignment keeps the same
    # list object that the other callbacks reference.
    new_data = [
        dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
        for writer_id in (id1, id2)
    ]
    all_loaded_data[:] = new_data
    # The cached mean styles and per-word DSDs were computed for the previous
    # writers. Rebuild them for the current words before redrawing; otherwise
    # the image keeps showing the old writers.
    image = update_target_word(" ".join(writer_words))
    return gr.Slider.update(label=f"{writer1} vs. {writer2}"), image
# for character blend | |
def interpolate_chars(c1, c2, weight):
    """Render one glyph that blends character ``c1`` into character ``c2``.

    The style vector and the character DSDs both come from the default writer
    (``default_loaded_data``). This is the same writer the tab's startup image
    is drawn from, so moving the slider does not switch writers.

    Args:
        c1, c2: the two characters to blend (entries of ``avail_char_list``).
        weight: blend factor; ``c1`` gets ``1 - weight`` and ``c2`` gets ``weight``.

    Returns:
        An RGB image: white strokes on a black 750x160 canvas.
    """
    net.clamp_mdn = 0
    letters = [c1, c2]
    character_weights = [1 - weight, weight]
    mean_global_W = convenience.get_mean_global_W(net, default_loaded_data, device)
    all_Cs = torch.zeros(1, len(letters), convenience.L, convenience.L)
    for i, letter in enumerate(letters):  # one character matrix per blend endpoint
        _, char_matrix = convenience.get_DSD(net, letter, [mean_global_W], [default_loaded_data], device)
        all_Cs[:, i, :, :] = char_matrix
    all_Ws = mean_global_W.reshape(1, 1, convenience.L)
    all_W_c = convenience.get_character_blend_W_c(character_weights, all_Ws, all_Cs)
    all_commands = convenience.get_commands(net, letters[0], all_W_c)

    width = 60
    x_offset = 325
    im = Image.fromarray(np.zeros([160, 750], dtype=np.uint8))
    dr = ImageDraw.Draw(im)
    prev = None
    for x, y, t in all_commands:
        # A command with t == 0 draws a segment from the previous point; any
        # other t only moves the pen. With no previous point yet there is
        # nothing to connect, which previously raised UnboundLocalError.
        if t == 0 and prev is not None:
            px, py = prev
            dr.line((
                px + width / 2 + x_offset,
                py - width / 2,  # letters are shifted down for some reason
                x + width / 2 + x_offset,
                y - width / 2), 255, 1)
        prev = (x, y)
    return im.convert("RGB")
def choose_blend_chars(c1, c2):
    """Return a slider update whose label names the two characters being blended."""
    slider_label = "'{}' vs. '{}'".format(c1, c2)
    return gr.Slider.update(label=slider_label)
# for MDN | |
def update_mdn_word(target_word):
    """Rebuild the "Add Randomness" tab's cached data for new text and redraw it.

    The style always comes from the default writer (``default_loaded_data``).
    All module-level ``mdn_*`` lists are updated in place.

    Args:
        target_word: text from the tab's textbox; split on whitespace.

    Returns:
        The image from ``sample_mdn`` using the network's current ``scale_sd``
        and ``clamp_mdn`` values.
    """
    # split() rather than split(" "): repeated, leading or trailing spaces would
    # otherwise yield empty words that get passed to get_DSD.
    mdn_words[:] = target_word.split()
    mdn_mean_Ws[:] = [convenience.get_mean_global_W(net, default_loaded_data, device)]
    all_word_mdn_Ws.clear()
    all_word_mdn_Cs.clear()
    for word in mdn_words:
        word_Ws, word_Cs = convenience.get_DSD(net, word, mdn_mean_Ws, [default_loaded_data], device)
        all_word_mdn_Ws.append(word_Ws)
        all_word_mdn_Cs.append(word_Cs)
    # NOTE(review): this reuses whatever scale_sd / clamp_mdn the network holds.
    # Other callbacks reset clamp_mdn to 0, so the redraw may not reflect the
    # tab's "Maximum Randomness" slider.
    return sample_mdn(net.scale_sd, net.clamp_mdn)
def sample_mdn(maxs, maxr):
    """Draw the "Add Randomness" words with new sampling settings.

    Args:
        maxs: stored as ``net.scale_sd`` (the UI feeds the "Scale of
            Randomness" slider here).
        maxr: stored as ``net.clamp_mdn`` (the UI feeds the "Maximum
            Randomness" slider here).

    Returns:
        The rendered words converted to RGB.
    """
    net.clamp_mdn = maxr
    net.scale_sd = maxs
    single_writer = [1]  # only the default writer's style is used in this tab
    drawing = convenience.draw_words(
        mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, single_writer, net
    )
    return drawing.convert("RGB")
# Prime the module-level caches (writer_* and mdn_* lists) that the
# "Blend Writers" and "Add Randomness" callbacks draw from, using the same
# default text the textboxes below start with.
update_target_word("hello world")
update_mdn_word("hello world")
# Three-tab Gradio app. Inside each `with` block, the order in which
# components are created determines their on-screen layout.
with gr.Blocks() as demo:
    with gr.Tabs():
        # Tab 1: interpolate between the styles of two writers.
        with gr.TabItem("Blend Writers"):
            target_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
            with gr.Row():
                # Alternate writer_options between the two radios (even indices
                # left, odd indices right), so the same writer id can never be
                # selected on both sides.
                left_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 0]
                right_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 1]
                with gr.Column():
                    writer1 = gr.Radio(left_ratio_options, value="Style 120", label="Style for first writer")
                with gr.Column():
                    writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
            with gr.Row():
                # value=0.7 matches the module-level default `weight`. The label
                # is rewritten by update_chosen_writers when the writers change.
                writer_slider = gr.Slider(0, 1, value=0.7, label="Style 120 vs. Style 80")
            with gr.Row():
                writer_submit = gr.Button("Submit")
            with gr.Row():
                # Startup image: [0.3, 0.7] is [1 - 0.7, 0.7], the slider's
                # initial position.
                writer_default_image = convenience.sample_blended_writers([0.3, 0.7], "hello world", net, all_loaded_data, device).convert("RGB")
                writer_output = gr.Image(writer_default_image)
            writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
            writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
            target_word.submit(fn=update_target_word, inputs=[target_word], outputs=[writer_output])
            # Changing either radio reloads both writers and relabels the slider.
            writer1.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
            writer2.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
        # Tab 2: blend one character into another.
        with gr.TabItem("Blend Characters"):
            with gr.Row():
                with gr.Column():
                    char1 = gr.Dropdown(choices=avail_char_list, value="y", label="Character 1")
                with gr.Column():
                    char2 = gr.Dropdown(choices=avail_char_list, value="s", label="Character 2")
            with gr.Row():
                char_slider = gr.Slider(0, 1, value=0.7, label="'y' vs. 's'")
            with gr.Row():
                # Startup image for the default pair ("y", "s") at slider 0.7.
                char_default_image = convenience.sample_blended_chars([0.3, 0.7], ["y", "s"], net, [default_loaded_data], device).convert("RGB")
                char_output = gr.Image(char_default_image)
            char_slider.change(fn=interpolate_chars, inputs=[char1, char2, char_slider], outputs=[char_output])
            # NOTE(review): changing a dropdown only relabels the slider; the
            # image is not redrawn until the slider itself moves.
            char1.change(fn=choose_blend_chars, inputs=[char1, char2], outputs=[char_slider])
            char2.change(fn=choose_blend_chars, inputs=[char1, char2], outputs=[char_slider])
        # Tab 3: resample the default writer's handwriting with adjustable randomness.
        with gr.TabItem("Add Randomness"):
            mdn_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
            # Disabled writer picker, kept as an inert string literal. It refers
            # to new_writer_mdn, slider3, slider4 and output, none of which are
            # defined in this file as shown, so it cannot simply be re-enabled.
            '''
with gr.Row():
radio_options3 = ["Writer " + str(n) for n in writer_options]
writer = gr.Radio(radio_options3, value="Writer 80", label="Style for Writer")
writer.change(fn=new_writer_mdn, inputs=[writer, slider3, slider4], outputs=[output])
'''
            with gr.Row():
                with gr.Column():
                    max_rand = gr.Slider(0, 1, value=1, label="Maximum Randomness")
                with gr.Column():
                    scale_rand = gr.Slider(0, 3, value=0.5, label="Scale of Randomness")
            with gr.Row():
                mdn_sample_button = gr.Button(value="Resample!")
            with gr.Row():
                # Startup image: 0.5 and 1 match the initial scale_rand and
                # max_rand slider values.
                default_im = convenience.mdn_single_sample("hello world", 0.5, 1, net, [default_loaded_data], device).convert('RGB')
                mdn_output = gr.Image(default_im)
            # sample_mdn takes (scale, max), hence the input order below.
            max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
            scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
            mdn_sample_button.click(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
            # NOTE(review): update_mdn_word redraws with the network's current
            # scale_sd / clamp_mdn, not with these sliders' values.
            mdn_word.submit(fn=update_mdn_word, inputs=[mdn_word], outputs=[mdn_output])
demo.launch()