File size: 2,358 Bytes
f7009b3
181711a
f7009b3
 
 
181711a
f7009b3
181711a
f7009b3
 
 
 
181711a
 
f7009b3
 
 
 
 
 
181711a
f7009b3
 
 
 
 
 
 
 
 
 
 
 
181711a
 
 
f7009b3
 
 
181711a
 
 
 
 
 
 
 
 
 
 
f7009b3
181711a
 
 
f7009b3
 
181711a
 
 
 
 
 
f7009b3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys, os, json
# Locate the repository root by walking the absolute path of this file down to
# the "Recurrent-Parameter-Generation" directory. abspath() is required because
# __file__ may be relative; a relative path either lacks the directory name
# (ValueError) or starts with it, which previously produced root == os.sep.
_path_parts = os.path.abspath(__file__).split(os.sep)
_repo_depth = _path_parts.index("Recurrent-Parameter-Generation") + 1
# _path_parts[0] is the empty string before the leading separator, so skip it.
root = os.sep + os.sep.join(_path_parts[1:_repo_depth])
# Make project packages importable and resolve relative paths from the root.
sys.path.append(root)
os.chdir(root)


# torch
import time
import torch
from torch import nn
# parent experiment module: supplies the datasets, config and model used below
from workspace.classinput import generalization as item


# Take the datasets prepared by the parent experiment module.
train_set, test_set = item.train_set, item.test_set
# Configure the training set's infinite mode with its own real length as the cap.
train_set.set_infinite_dataset(max_num=train_set.real_length)
print("num_generated:", test_set.real_length)

# Reuse the parent's config and model, overriding the model's
# "diffusion_batch" setting for generation.
config, model = item.config, item.model
model.config["diffusion_batch"] = 128
# The experiment config must carry a tag before anything is generated.
assert config.get("tag") is not None, "Remember to set a tag."

# Model
print('==> Building model..')
# Load the checkpoint onto the CPU: without map_location, torch.load restores
# each tensor to the device it was saved from, which fails when that device is
# not present and otherwise stages the weights on the GPU for no benefit.
diction = torch.load("./checkpoint/generalization.pth", map_location="cpu")
# Rebuild the permutation-state embedding with the checkpoint's table shape
# before loading (presumably so load_state_dict does not hit a size mismatch).
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.cuda()

# generate
print('==> Defining generate..')


def generate(save_path, embedding, need_test=True):
    """Sample parameters from the model for one condition embedding.

    While the model is sampling, a background thread keeps printing the
    elapsed time. The thread is stopped through an event inside a ``finally``
    clause, so an error (or Ctrl-C) during sampling propagates normally
    instead of leaving the timer thread running and blocking process exit.

    Args:
        save_path: Format string with one placeholder; it is filled with the
            zero-padded class index to build the output path. Only used when
            ``need_test`` is true.
        embedding: Condition passed to the model with a leading axis added
            (``embedding[None]``); presumably a tensor.
        need_test: When true, prompt for the expected class embedding, save
            the generated parameters through ``train_set.save_params`` and
            run ``test_set.test_command`` followed by the saved path.

    Returns:
        The model output for the given condition.
    """
    import threading  # kept function-local, as the file does not import it globally

    print("\n==> Generating..")
    model.eval()
    finished = threading.Event()
    start_time = time.time()

    def display_time():
        # Refresh the elapsed-time line until sampling finishes or fails.
        while not finished.is_set():
            elapsed_minutes = (time.time() - start_time) / 60
            print(f"Elapsed time: {elapsed_minutes:.2f} minutes", end="\r")
            # Waits up to 0.1 s but wakes immediately once sampling ends.
            finished.wait(0.1)

    # daemon=True is a second safety net: the timer can never keep the
    # interpreter alive on its own.
    timer_thread = threading.Thread(target=display_time, daemon=True)
    timer_thread.start()
    try:
        with torch.no_grad():
            prediction = model(sample=True, condition=embedding[None], permutation_state=False)
    finally:
        # Always stop the timer, even when the model call raises.
        finished.set()
        timer_thread.join()
        print()

    generated_norm = torch.nanmean((prediction.cpu()).abs())
    print("Generated_norm:", generated_norm.item())
    if need_test:
        # Example input: [0,0,1,1,1,1,1,1,0,0]
        real_emb = input("Input your expected class (ONLY FOR EVALUATING): ")
        # NOTE(review): eval() executes whatever is typed at the prompt;
        # ast.literal_eval would be the safer parser for a list literal.
        # Left unchanged here so every currently accepted input still works.
        real_emb = torch.tensor(eval(real_emb), dtype=torch.float)
        # Read the entries as binary digits and convert to a decimal index,
        # zero-padded to four characters (e.g. 0011111100 -> "0252").
        class_index = str(int("".join([str(int(i)) for i in real_emb]), 2)).zfill(4)
        output_path = save_path.format(class_index)
        train_set.save_params(prediction, save_path=output_path)
        print("Saved to:", output_path)
        # The test command is used as a prefix; the saved path is appended directly.
        test_command = test_set.test_command + output_path
        os.system(test_command)
    model.train()
    return prediction