rom1504 commited on
Commit
3fc663b
1 Parent(s): 42559f5
Files changed (3) hide show
  1. app.py +84 -5
  2. model_paths.json +1 -0
  3. requirements.txt +5 -0
app.py CHANGED
@@ -1,10 +1,89 @@
1
- from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
2
- from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import gradio as gr
5
 
6
- def greet(name):
7
- return "Hello " + name + "!!"
8
 
9
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  iface.launch()
 
1
+ from tqdm import tqdm
2
+ import numpy as np
3
 
4
+ from pathlib import Path
5
+ import json
6
+
7
+ # torch
8
+
9
+ import torch
10
+
11
+ from einops import repeat
12
+
13
+ # vision imports
14
+
15
+ from PIL import Image
16
+
17
+ # dalle related classes and utils
18
+
19
+ from dalle_pytorch import VQGanVAE, DALLE
20
+ from dalle_pytorch.tokenizer import tokenizer
21
+
22
+ from io import BytesIO
23
  import gradio as gr
24
 
 
 
25
 
26
+ # load DALL-E
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+ models = json.load(open("model_paths.json"))
32
+
33
+
34
+ vae = VQGanVAE(None, None)
35
+
36
+ dalles = {}
37
+
38
+ for name, model_path in models.items():
39
+ assert Path(model_path).exists(), 'trained DALL-E '+model_path+' must exist'
40
+ load_obj = torch.load(model_path)
41
+ dalle_params, _, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
42
+ dalle_params.pop('vae', None) # cleanup later
43
+
44
+ dalle = DALLE(vae = vae, **dalle_params).cuda()
45
+
46
+ dalle.load_state_dict(weights)
47
+ dalles[name] = dalle
48
+
49
+
50
+
51
+ batch_size = 4
52
+
53
+ top_k = 0.9
54
+
55
+ # generate images
56
+
57
+ image_size = vae.image_size
58
+
59
+ def generate(text):
60
+ text_input = text
61
+ num_images = 4
62
+ dalle_name = "weird_car"
63
+ dalle = dalles[dalle_name]
64
+
65
+ text = tokenizer.tokenize([text_input], dalle.text_seq_len).cuda()
66
+
67
+ text = repeat(text, '() n -> b n', b = num_images)
68
+
69
+ outputs = []
70
+
71
+ for text_chunk in tqdm(text.split(batch_size), desc = f'generating images for - {text}'):
72
+ output = dalle.generate_images(text_chunk, filter_thres = top_k)
73
+ outputs.append(output)
74
+
75
+ outputs = torch.cat(outputs)
76
+
77
+ response = []
78
+
79
+ for image in tqdm(outputs, desc = 'saving images'):
80
+ np_image = np.moveaxis(image.cpu().numpy(), 0, -1)
81
+ formatted = (np_image * 255).astype('uint8')
82
+
83
+ img = Image.fromarray(formatted)
84
+ response.append(img)
85
+
86
+ return response
87
+
88
+ iface = gr.Interface(fn=generate, inputs="text", outputs=gr.outputs.Carousel("image"))
89
  iface.launch()
model_paths.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"weird_car":"weird_car_model_continue.pt"}
requirements.txt CHANGED
@@ -1 +1,6 @@
1
  dalle-pytorch
 
 
 
 
 
 
1
  dalle-pytorch
2
+ numpy
3
+ tqdm
4
+ torch
5
+ torchvision
6
+ einops