maj34 committed on
Commit b5823bd
1 Parent(s): a017c81

Upload app.py

Files changed (1)
app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import torchvision.transforms as T
+ from dalle_pytorch import VQGanVAE
+ from dalle.models import DALLE_Klue_Roberta
+ from transformers import AutoTokenizer
+ import gradio as gr
+
+ import yaml
+ from easydict import EasyDict
+
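+ # Model artifacts: DALL-E config/weights, plus the VQGAN config and checkpoint
+ # used as the image tokenizer (note the machine-specific absolute paths below).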
+ dalle_config_path = 'configs/dalle_config.yaml'
+ dalle_path = 'results/dalle_uk_final.pt'
+
+ vqgan_config_path = '/home/brad/Development/taming-transformers/configs/VQGAN_blue.yaml'
+ vqgan_path = '/home/brad/Development/taming-transformers/logs/2022-07-21T12-44-12_VQGAN_blue/checkpoints/best.ckpt'
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
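+ # Korean tokenizer matching the KLUE RoBERTa text encoder used by the model.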
+ tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
+
+ with open(dalle_config_path, "r") as f:
+     dalle_config = yaml.load(f, Loader=yaml.Loader)
+     DALLE_CFG = EasyDict(dalle_config["DALLE_CFG"])
+
+ DALLE_CFG.VOCAB_SIZE = tokenizer.vocab_size
+
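+ # VQGAN maps between pixels and a grid of discrete codes; DALL-E predicts those
+ # codes, and the VQGAN decoder turns them back into an image.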
+ vae = VQGanVAE(
+     vqgan_model_path=vqgan_path,
+     vqgan_config_path=vqgan_config_path
+ )
+
+ DALLE_CFG.IMAGE_SIZE = vae.image_size
+
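+ # Transformer hyperparameters for DALL-E, taken from the YAML config.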
+ dalle_params = dict(
+     num_text_tokens=tokenizer.vocab_size,
+     text_seq_len=DALLE_CFG.TEXT_SEQ_LEN,
+     depth=DALLE_CFG.DEPTH,
+     heads=DALLE_CFG.HEADS,
+     dim_head=DALLE_CFG.DIM_HEAD,
+     reversible=DALLE_CFG.REVERSIBLE,
+     loss_img_weight=DALLE_CFG.LOSS_IMG_WEIGHT,
+     attn_types=DALLE_CFG.ATTN_TYPES,
+     ff_dropout=DALLE_CFG.FF_DROPOUT,
+     attn_dropout=DALLE_CFG.ATTN_DROPOUT,
+     stable=DALLE_CFG.STABLE,
+     shift_tokens=DALLE_CFG.SHIFT_TOKENS,
+     rotary_emb=DALLE_CFG.ROTARY_EMB,
+ )
+
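+ # DALL-E variant whose text side is initialized from KLUE RoBERTa word (wte)
+ # and position (wpe) embeddings.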
+ dalle = DALLE_Klue_Roberta(
+     vae=vae,
+     wte_dir="models/roberta_large_wte.pt",
+     wpe_dir="models/roberta_large_wpe.pt",
+     **dalle_params
+ ).to(device)
+
+
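+ # Restore the trained weights; map_location follows the selected device so the
+ # app also starts on CPU-only hosts (the original hard-coded 'cuda:0' here).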
+ loaded_obj = torch.load(dalle_path, map_location=device)
+ dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
+ dalle.load_state_dict(weights)
+
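+ # Gradio handler: tokenize the prompt, autoregressively sample image tokens,
+ # then decode them to pixels with the VQGAN.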
+ def text_to_montage(text):
+     encoded_dict = tokenizer(
+         text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=DALLE_CFG.TEXT_SEQ_LEN,
+         add_special_tokens=True,
+         return_token_type_ids=True,  # for RoBERTa
+     ).to(device)
+
+     encoded_text = encoded_dict['input_ids']
+     mask = encoded_dict['attention_mask']
+
+     image = dalle.generate_images(
+         encoded_text,
+         mask=mask,
+         filter_thres=0.9  # top-k sampling; 0.9 keeps only the top 10% of logits
+     )
+
+     return T.ToPILImage()(image.squeeze())
+
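+ # Text-in, image-out demo; 0.0.0.0 exposes it on all network interfaces.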
+ demo = gr.Interface(fn=text_to_montage, inputs="text", outputs="image")
+
+ demo.launch(server_name="0.0.0.0")