File size: 5,553 Bytes
b65c5e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

import torch
import argparse
import numpy as np
from helper import *
from config.GlobalVariables import *
from SynthesisNetwork import SynthesisNetwork
from DataLoader import DataLoader
import convenience

# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file; presumably it mirrors weight_dim=256 used in main() — confirm
# before relying on it or removing it.
L = 256


def _save_image(im, path):
    """Save ``im`` as an RGB image at ``path``, creating the parent directory if needed."""
    import os

    parent = os.path.dirname(path)
    if parent:
        # Avoid a FileNotFoundError on a fresh checkout without a results folder.
        os.makedirs(parent, exist_ok=True)
    im.convert("RGB").save(path)


def main(params):
    """Generate a handwriting sample (image, grid or video) from parsed CLI arguments.

    Seeds NumPy and PyTorch with 0, builds the synthesis network on the CPU,
    loads the checkpoint ``./model/250000.pt``, loads data for every writer in
    ``params.writer_ids`` and then dispatches on ``params.output`` and
    ``params.interpolate`` to the matching routine in ``convenience``.

    Args:
        params: Namespace with the attributes read below: ``writer_ids``,
            ``num_samples``, ``output``, ``interpolate``, ``blend_weights``,
            ``blend_chars``, ``grid_chars``, ``grid_size``, ``target_word``,
            ``frames_per_step``, ``max_randomness``, ``scale_randomness`` and
            ``num_random_samples``.

    Raises:
        ValueError: If the output/interpolate combination is not supported, if
            ``blend_weights`` does not match the length of the list it weights,
            if ``grid_chars`` does not contain exactly four entries, or if
            ``max_randomness`` is outside ``[0, 1]``.
    """
    np.random.seed(0)
    torch.manual_seed(0)

    device = 'cpu'

    net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)

    # The network always runs on the CPU here, so the weights must be loaded
    # regardless of whether CUDA happens to be available on this machine.
    checkpoint = torch.load('./model/250000.pt', map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        # Retrained checkpoints wrap the weights (alongside e.g. the loss) in a dict.
        checkpoint = checkpoint["model_state_dict"]
    net.load_state_dict(checkpoint)

    dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')

    sample_ids = list(range(params.num_samples))
    all_loaded_data = [
        dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=sample_ids)
        for writer_id in params.writer_ids
    ]

    if params.output == "image":

        if params.interpolate == "writer":
            if len(params.blend_weights) != len(params.writer_ids):
                raise ValueError("blend_weights must be same length as writer_ids")
            im = convenience.sample_blended_writers(params.blend_weights, params.target_word, net, all_loaded_data, device)
            writer_tag = "+".join(str(i) for i in params.writer_ids)
            _save_image(im, f'results/blend_{writer_tag}.png')
        elif params.interpolate == "character":
            if len(params.blend_weights) != len(params.blend_chars):
                # The weights are compared against blend_chars, so name that argument.
                raise ValueError("blend_weights must be same length as blend_chars")
            im = convenience.sample_blended_chars(params.blend_weights, params.blend_chars, net, all_loaded_data, device)
            char_tag = "+".join(params.blend_chars)
            _save_image(im, f'results/blend_{char_tag}.png')
        elif params.interpolate == "randomness":
            if not 0 <= params.max_randomness <= 1:
                raise ValueError("max_randomness must be between 0 and 1")
            im = convenience.mdn_single_sample(params.target_word, params.scale_randomness, params.max_randomness, net, all_loaded_data, device)
            _save_image(im, f"results/sample_{params.target_word.replace(' ', '_')}.png")
        else:
            raise ValueError("Invalid interpolation argument for outputting an image")
    elif params.output == "grid":

        if params.interpolate == "character":
            if len(params.grid_chars) != 4:
                raise ValueError("grid_chars must be given exactly four characters")
            im = convenience.sample_character_grid(params.grid_chars, params.grid_size, net, all_loaded_data, device)
            grid_tag = "".join(params.grid_chars)
            _save_image(im, f'results/grid_{grid_tag}.png')
        else:
            raise ValueError("Invalid interpolation argument for outputting a grid")
    elif params.output == "video":

        # The video helpers return nothing that is used here; any output they
        # produce is handled inside ``convenience``.
        if params.interpolate == "writer":
            convenience.writer_interpolation_video(params.target_word, params.frames_per_step, net, all_loaded_data, device)
        elif params.interpolate == "character":
            convenience.char_interpolation_video(params.blend_chars, params.frames_per_step, net, all_loaded_data, device)
        elif params.interpolate == "randomness":
            if not 0 <= params.max_randomness <= 1:
                raise ValueError("max_randomness must be between 0 and 1")
            convenience.mdn_video(params.target_word, params.num_random_samples, params.scale_randomness, params.max_randomness, net, all_loaded_data, device)
        else:
            raise ValueError("Invalid interpolation argument for outputting a video")
    else:
        raise ValueError("Invalid output")

if __name__ == '__main__':
    # Command-line entry point: declare every option read by main() and run it.
    parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')

    parser.add_argument('--num_samples', type=int, default=10,
                        help='number of sample ids (0..N-1) requested per writer when loading data')
    parser.add_argument('--generating_default', type=int, default=0,
                        help='accepted for compatibility; not read by main() in this script')

    parser.add_argument('--output', type=str, default="image", choices=["image", "grid", "video"],
                        help='kind of result to produce')
    parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"],
                        help='what to interpolate over; "grid" output only supports "character"')

    # PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
    # IF IMAGE - weights to use for a single sample of interpolation
    parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5],
                        help='image output: one weight per writer id (writer) or per blend character (character)')
    # IF VIDEO - the number of frames for each character/writer
    parser.add_argument('--frames_per_step', type=int, default=10,
                        help='video output: number of frames for each character/writer')

    # PARAMS IF WRITER OR RANDOMNESS INTERPOLATION:
    parser.add_argument('--target_word', type=str, default="hello world",
                        help='text to synthesize for writer and randomness interpolation')
    # Data is loaded for every listed writer, whatever the interpolation mode.
    parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120],
                        help='ids of the writers whose data is loaded')

    # PARAMS IF CHARACTER INTERPOLATION:
    # IF VIDEO OR BLEND
    parser.add_argument('--blend_chars', type=str, nargs="+", default=["a", "b", "c", "d", "e"],
                        help='characters to blend (image) or to interpolate between (video)')
    # IF GRID
    parser.add_argument('--grid_chars', type=str, nargs="+", default=["y", "s", "u", "n"],
                        help='exactly four characters for the character grid')
    parser.add_argument('--grid_size', type=int, default=10,
                        help='grid size passed to the character grid sampler')

    # PARAMS IF RANDOMNESS INTERPOLATION (--output selects image or video; grid is rejected):
    parser.add_argument('--max_randomness', type=float, default=1.0,
                        help='must be between 0 and 1')
    parser.add_argument('--scale_randomness', type=float, default=0.5,
                        help='randomness scale passed to the sampling routines')
    parser.add_argument('--num_random_samples', type=int, default=10,
                        help='video output: number of random samples passed to the randomness video routine')

    main(parser.parse_args())