Khalil committed on
Commit
b41a54a
1 Parent(s): 416f940

First commit: add text2punks scripts, app file, and requirements file

app.py ADDED
@@ -0,0 +1,82 @@
1
+ # system
2
+
3
+ import os
4
+
5
+ os.system("gdown https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH")
6
+ os.system("gdown https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF")
7
+
8
+ # plot
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ # gradio
15
+
16
+ import gradio as gr
17
+
18
+ # text2punks utils
19
+
20
+ from text2punks.utils import to_pil_image, model_loader, generate_image
21
+
22
+
23
+ batch_size = 32
24
+ num_images = 32
25
+ top_prediction = 8
26
+
27
+ # knobs to tune
28
+
29
+ top_k = 0.8
30
+ temperature = 1.25
31
+
32
+ # helper functions
33
+
34
+ def compose_predictions(images):
35
+
36
+ increased_h = 0
37
+ h, w = images[0].shape[0], images[0].shape[1]
38
+ image_grid = Image.new("RGB", (len(images)*w, h))
39
+
40
+ for i, img_ in enumerate(images):
41
+ image_grid.paste(to_pil_image(img_), (i*w, increased_h))
42
+
43
+ return image_grid
44
+
45
+
46
+ def run_inference(prompt, num_images=32, num_preds=8):
47
+
48
+ t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
49
+ text2punk, clip = model_loader(t2p_path, clip_path)
50
+
51
+ images = generate_image(prompt_text=prompt, top_k=top_k, temperature=temperature, num_images=num_images, batch_size=batch_size, top_prediction=top_prediction, text2punk_model=text2punk, clip_model=clip)
52
+ predictions = compose_predictions(images)
53
+
54
+ output_title = f"""
55
+ <b>{prompt}</b>
56
+ """
57
+
58
+ return (output_title, predictions)
59
+
60
+
61
+ outputs = [
62
+ gr.outputs.HTML(label=""), # To be used as title
63
+ gr.outputs.Image(label=''),
64
+ ]
65
+
66
+ description = """
67
+ Text2Cryptopunks is an AI model that generates CryptoPunks images from a text prompt:
68
+ """
69
+
70
+ gr.Interface(run_inference,
71
+ inputs=[gr.inputs.Textbox(label='type something like this: "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
72
+ outputs=outputs,
73
+ title='Text2Cryptopunks',
74
+ description=description,
75
+ article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a></p>",
76
+ layout='vertical',
77
+ theme='huggingface',
78
+ examples=[['Cute Alien cryptopunk that has 2 Attributes, a Pipe, and a Beanie.'], ['A low resolution photo of punky-looking Ape that has 2 Attributes, a Beanie, and a Medical Mask.']],
79
+ allow_flagging=False,
80
+ live=False,
81
+ # server_port=8999
82
+ ).launch(share=True)
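
A minimal sketch of exercising the same inference path outside Gradio, assuming the two checkpoints above have already been downloaded into the working directory by the gdown calls:

# local smoke test of the inference path used by run_inference above
from text2punks.utils import model_loader, generate_image

text2punk, clip = model_loader('./Text2Punk-final-7.pt', './clip-final.pt')
images, scores = generate_image(
    prompt_text='An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask.',
    top_k=0.8, temperature=1.25, num_images=32, batch_size=32,
    top_prediction=8, text2punk_model=text2punk, clip_model=clip)
print(images.shape)  # (8, 3, 24, 24): the top_prediction punks as 24x24 RGB tensors
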
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch
2
+ torchvision
3
+ einops
4
+ numpy
5
+ ftfy
6
+ regex
7
+ axial-positional-embedding
8
+ youtokentome
9
+ tokenizers
text2punks/attention.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+
6
+ # helpers
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+ def max_neg_value(t):
12
+ return -torch.finfo(t.dtype).max
13
+
14
+
15
+ # classes
16
+
17
+ class Attention(nn.Module):
18
+ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
19
+ super().__init__()
20
+ inner_dim = dim_head * heads
21
+ self.heads = heads
22
+ self.seq_len = seq_len
23
+ self.scale = dim_head ** -0.5
24
+
25
+ self.causal = causal
26
+ self.attn_drop = nn.Dropout(attn_dropout)
27
+
28
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
29
+ self.to_out = nn.Sequential(
30
+ nn.Linear(inner_dim, dim),
31
+ nn.Dropout(resid_dropout)
32
+ )
33
+
34
+ def forward(self, x):
35
+ h, device = self.heads, x.device
36
+
37
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
38
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
39
+
40
+ q = q * self.scale
41
+
42
+ dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
43
+ mask_value = max_neg_value(dots)
44
+
45
+ if self.causal:
46
+ i, j = dots.shape[-2:]
47
+ mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
48
+ dots.masked_fill_(mask, mask_value)
49
+
50
+ attn = torch.softmax(dots, dim=-1)
51
+ attn = self.attn_drop(attn)
52
+
53
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
54
+ out = rearrange(out, 'b h n d -> b n (h d)')
55
+ out = self.to_out(out)
56
+ return out
57
+
58
+
59
+ # sparse axial causal attention
60
+
61
+ class SparseAxialCausalAttention(nn.Module):
62
+ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
63
+ super().__init__()
64
+ assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
65
+ self.axis = axis
66
+
67
+ inner_dim = dim_head * heads
68
+ self.seq_len = seq_len
69
+ self.heads = heads
70
+ self.scale = dim_head ** -0.5
71
+ self.image_size = image_size
72
+ self.attn_drop = nn.Dropout(attn_dropout)
73
+
74
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
75
+
76
+ self.to_out = nn.Sequential(
77
+ nn.Linear(inner_dim, dim),
78
+ nn.Dropout(resid_dropout)
79
+ )
80
+
81
+ def forward(self, x):
82
+ b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
83
+
84
+ img_seq_len = img_size ** 2
85
+ text_len = seq_len + 1 - img_seq_len
86
+
87
+ # padding
88
+
89
+ padding = seq_len - n + 1
90
+ mask = torch.ones(b, text_len, device = device).bool()
91
+
92
+ x = F.pad(x, (0, 0, 0, padding), value = 0)
93
+ mask = mask[:, :text_len]
94
+
95
+ # derive queries / keys / values
96
+
97
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
98
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
99
+
100
+ # print(self.scale)
101
+ q = q * self.scale
102
+
103
+ ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
104
+
105
+ # text attention
106
+
107
+ dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
108
+ mask_value = max_neg_value(dots_text)
109
+
110
+ i, j = dots_text.shape[-2:]
111
+ text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
112
+ dots_text.masked_fill_(text_causal_mask, mask_value)
113
+
114
+ attn_text = torch.softmax(dots_text, dim = -1)
115
+
116
+ # attention dropout
117
+
118
+ attn_text = self.attn_drop(attn_text)
119
+ out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
120
+
121
+ # image attention
122
+
123
+ split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
124
+ merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
125
+
126
+ # split out axis
127
+
128
+ q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))
129
+
130
+ # similarity
131
+
132
+ dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
133
+ dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
134
+
135
+ dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
136
+
137
+ # mask so image has full attention to text, but causal along axis
138
+
139
+ bh, x, i, j = dots.shape
140
+ causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
141
+ causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)
142
+
143
+ mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
144
+ mask = torch.cat((~mask, causal_mask), dim = -1)
145
+
146
+ dots.masked_fill_(mask, mask_value)
147
+
148
+ # attention.
149
+
150
+ attn = torch.softmax(dots, dim = -1)
151
+
152
+ # attention dropout
153
+
154
+ attn = self.attn_drop(attn)
155
+
156
+ # aggregate
157
+
158
+ attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
159
+
160
+ out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
161
+ out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
162
+
163
+ out_image = out_image_to_image + out_image_to_text
164
+
165
+ # merge back axis
166
+
167
+ out_image = rearrange(out_image, merge_axis_einops, x = img_size)
168
+
169
+ # combine attended values for both text and image
170
+
171
+ out = torch.cat((out_text, out_image), dim = 1)
172
+
173
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
174
+ out = self.to_out(out)
175
+ return out[:, :n]
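
A quick, illustrative shape check of the full causal attention block (the sizes below are arbitrary, not the trained configuration):

# shape check: causal Attention maps (batch, seq, dim) -> (batch, seq, dim)
import torch
from text2punks.attention import Attention

attn = Attention(dim=64, seq_len=16, causal=True, heads=4, dim_head=16)
x = torch.randn(2, 16, 64)
assert attn(x).shape == (2, 16, 64)
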
text2punks/data/byte-level-bpe_4k.tokenizer.json ADDED
@@ -0,0 +1,969 @@
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "special": true,
9
+ "content": "[PAD]",
10
+ "single_word": false,
11
+ "lstrip": false,
12
+ "rstrip": false,
13
+ "normalized": false
14
+ },
15
+ {
16
+ "id": 1,
17
+ "special": true,
18
+ "content": "[SEP]",
19
+ "single_word": false,
20
+ "lstrip": false,
21
+ "rstrip": false,
22
+ "normalized": false
23
+ }
24
+ ],
25
+ "normalizer": {
26
+ "type": "Lowercase"
27
+ },
28
+ "pre_tokenizer": {
29
+ "type": "ByteLevel",
30
+ "add_prefix_space": false,
31
+ "trim_offsets": true
32
+ },
33
+ "post_processor": {
34
+ "type": "ByteLevel",
35
+ "add_prefix_space": true,
36
+ "trim_offsets": true
37
+ },
38
+ "decoder": {
39
+ "type": "ByteLevel",
40
+ "add_prefix_space": true,
41
+ "trim_offsets": true
42
+ },
43
+ "model": {
44
+ "type": "BPE",
45
+ "dropout": null,
46
+ "unk_token": null,
47
+ "continuing_subword_prefix": null,
48
+ "end_of_word_suffix": null,
49
+ "fuse_unk": false,
50
+ "vocab": {
51
+ "[PAD]": 0,
52
+ "[SEP]": 1,
53
+ ",": 2,
54
+ "-": 3,
55
+ ".": 4,
56
+ "0": 5,
57
+ "1": 6,
58
+ "2": 7,
59
+ "3": 8,
60
+ "4": 9,
61
+ "5": 10,
62
+ "6": 11,
63
+ "7": 12,
64
+ "?": 13,
65
+ "a": 14,
66
+ "b": 15,
67
+ "c": 16,
68
+ "d": 17,
69
+ "e": 18,
70
+ "f": 19,
71
+ "g": 20,
72
+ "h": 21,
73
+ "i": 22,
74
+ "k": 23,
75
+ "l": 24,
76
+ "m": 25,
77
+ "n": 26,
78
+ "o": 27,
79
+ "p": 28,
80
+ "r": 29,
81
+ "s": 30,
82
+ "t": 31,
83
+ "u": 32,
84
+ "v": 33,
85
+ "w": 34,
86
+ "x": 35,
87
+ "y": 36,
88
+ "z": 37,
89
+ "Ċ": 38,
90
+ "Ġ": 39,
91
+ "Ġa": 40,
92
+ "nd": 41,
93
+ "Ġb": 42,
94
+ "ha": 43,
95
+ "le": 44,
96
+ "ma": 45,
97
+ "Ġc": 46,
98
+ "ro": 47,
99
+ "pu": 48,
100
+ "ck": 49,
101
+ "to": 50,
102
+ "Ġand": 51,
103
+ "ack": 52,
104
+ "ar": 53,
105
+ "Ġma": 54,
106
+ "nk": 55,
107
+ "gro": 56,
108
+ "und": 57,
109
+ "Ġback": 58,
110
+ "punk": 59,
111
+ "ground": 60,
112
+ "Ġbackground": 61,
113
+ "Ġha": 62,
114
+ "Ġcr": 63,
115
+ "in": 64,
116
+ "Ġs": 65,
117
+ "Ġo": 66,
118
+ "pto": 67,
119
+ "ypto": 68,
120
+ "Ġcrypto": 69,
121
+ "Ġcryptopunk": 70,
122
+ "Ġof": 71,
123
+ "es": 72,
124
+ "Ġw": 73,
125
+ "Ġwi": 74,
126
+ "th": 75,
127
+ "ing": 76,
128
+ "ho": 77,
129
+ "lo": 78,
130
+ "Ġwith": 79,
131
+ "Ġp": 80,
132
+ "Ġmale": 81,
133
+ "Ġg": 82,
134
+ "Ġf": 83,
135
+ "ear": 84,
136
+ "Ġhas": 85,
137
+ "Ġm": 86,
138
+ "lu": 87,
139
+ "Ġt": 88,
140
+ "re": 89,
141
+ "Ġfe": 90,
142
+ "de": 91,
143
+ "wn": 92,
144
+ "ok": 93,
145
+ "hoto": 94,
146
+ "look": 95,
147
+ "Ġphoto": 96,
148
+ "hat": 97,
149
+ "Ġthat": 98,
150
+ "Ġmade": 99,
151
+ "male": 100,
152
+ "Ġfemale": 101,
153
+ "tr": 102,
154
+ "ir": 103,
155
+ "Ġhair": 104,
156
+ "Ġsha": 105,
157
+ "ple": 106,
158
+ "rple": 107,
159
+ "Ġpu": 108,
160
+ "Ġpurple": 109,
161
+ "ut": 110,
162
+ "en": 111,
163
+ "Ġgre": 112,
164
+ "Ġgreen": 113,
165
+ "Ġshad": 114,
166
+ "Ġblu": 115,
167
+ "Ġblue": 116,
168
+ "Ġmo": 117,
169
+ "looking": 118,
170
+ "li": 119,
171
+ "Ġr": 120,
172
+ "rown": 121,
173
+ "ti": 122,
174
+ "Ġli": 123,
175
+ "la": 124,
176
+ "ap": 125,
177
+ "Ġbear": 126,
178
+ "Ġbeard": 127,
179
+ "are": 128,
180
+ "Ġbrown": 129,
181
+ "wk": 130,
182
+ "hawk": 131,
183
+ "Ġmohawk": 132,
184
+ "ig": 133,
185
+ "ring": 134,
186
+ "Ġear": 135,
187
+ "Ġearring": 136,
188
+ "ed": 137,
189
+ "ey": 138,
190
+ "Ġey": 139,
191
+ "but": 140,
192
+ "ibut": 141,
193
+ "ttr": 142,
194
+ "Ġattr": 143,
195
+ "punky": 144,
196
+ "Ġattribut": 145,
197
+ "ps": 146,
198
+ "Ġattributes": 147,
199
+ "Ġcap": 148,
200
+ "ss": 149,
201
+ "tick": 150,
202
+ "Ġlips": 151,
203
+ "Ġlipstick": 152,
204
+ "Ġshades": 153,
205
+ "lass": 154,
206
+ "Ġn": 155,
207
+ "Ġeye": 156,
208
+ "Ġpunky": 157,
209
+ "Ġrare": 158,
210
+ "Ġho": 159,
211
+ "ow": 160,
212
+ "Ġglass": 161,
213
+ "Ġglasses": 162,
214
+ "tt": 163,
215
+ "Ġshadow": 164,
216
+ "Ġ3": 165,
217
+ "Ġgo": 166,
218
+ "or": 167,
219
+ "Ġd": 168,
220
+ "mal": 169,
221
+ "Ġpi": 170,
222
+ "Ġstr": 171,
223
+ "he": 172,
224
+ "Ġclo": 173,
225
+ "Ġclown": 174,
226
+ "ke": 175,
227
+ "od": 176,
228
+ "ark": 177,
229
+ "Ġdark": 178,
230
+ "ld": 179,
231
+ "ce": 180,
232
+ "Ġcig": 181,
233
+ "arett": 182,
234
+ "Ġcigarett": 183,
235
+ "Ġcigarette": 184,
236
+ "lack": 185,
237
+ "Ġblack": 186,
238
+ "and": 187,
239
+ "Ġnor": 188,
240
+ "Ġnormal": 189,
241
+ "Ġ2": 190,
242
+ "nt": 191,
243
+ "ront": 192,
244
+ "Ġfront": 193,
245
+ "Ġlooking": 194,
246
+ "car": 195,
247
+ "cut": 196,
248
+ "ela": 197,
249
+ "on": 198,
250
+ "olu": 199,
251
+ "ted": 200,
252
+ "up": 201,
253
+ "xela": 202,
254
+ "Ġlo": 203,
255
+ "Ġlook": 204,
256
+ "Ġup": 205,
257
+ "Ġsing": 206,
258
+ "Ġscar": 207,
259
+ "esolu": 208,
260
+ "how": 209,
261
+ "Ġresolu": 210,
262
+ "tion": 211,
263
+ "Ġlike": 212,
264
+ "Ġgood": 213,
265
+ "Ġpixela": 214,
266
+ "cute": 215,
267
+ "Ġlow": 216,
268
+ "Ġsingle": 217,
269
+ "Ġscarce": 218,
270
+ "Ġresolution": 219,
271
+ "Ġpixelated": 220,
272
+ "fu": 221,
273
+ "nn": 222,
274
+ "funn": 223,
275
+ "funny": 224,
276
+ "Ġeyes": 225,
277
+ "Ġhe": 226,
278
+ "at": 227,
279
+ "Ġv": 228,
280
+ "aig": 229,
281
+ "ht": 230,
282
+ "Ġstraig": 231,
283
+ "Ġstraight": 232,
284
+ "er": 233,
285
+ "Ġwild": 234,
286
+ "ad": 235,
287
+ "Ġhead": 236,
288
+ "Ġhot": 237,
289
+ "Ġbig": 238,
290
+ "ic": 239,
291
+ "Ġre": 240,
292
+ "Ġmole": 241,
293
+ "an": 242,
294
+ "mp": 243,
295
+ "sy": 244,
296
+ "us": 245,
297
+ "Ġner": 246,
298
+ "Ġnerd": 247,
299
+ "nde": 248,
300
+ "Ġblo": 249,
301
+ "Ġblonde": 250,
302
+ "im": 251,
303
+ "ned": 252,
304
+ "rned": 253,
305
+ "Ġrim": 254,
306
+ "Ġhorned": 255,
307
+ "Ġhat": 256,
308
+ "gu": 257,
309
+ "lar": 258,
310
+ "Ġregu": 259,
311
+ "Ġregular": 260,
312
+ "Ġclass": 261,
313
+ "Ġclassic": 262,
314
+ "Ġband": 263,
315
+ "ana": 264,
316
+ "Ġbandana": 265,
317
+ "sk": 266,
318
+ "Ġmask": 267,
319
+ "ingy": 268,
320
+ "Ġstringy": 269,
321
+ "ch": 270,
322
+ "Ġpat": 271,
323
+ "Ġpatch": 272,
324
+ "essy": 273,
325
+ "Ġmessy": 274,
326
+ "ved": 275,
327
+ "Ġshaved": 276,
328
+ "ru": 277,
329
+ "Ġfru": 278,
330
+ "mpy": 279,
331
+ "Ġfrumpy": 280,
332
+ "Ġth": 281,
333
+ "Ġthin": 282,
334
+ "Ġsp": 283,
335
+ "itt": 284,
336
+ "kn": 285,
337
+ "Ġkn": 286,
338
+ "itted": 287,
339
+ "Ġknitted": 288,
340
+ "az": 289,
341
+ "Ġcraz": 290,
342
+ "Ġcrazy": 291,
343
+ "band": 292,
344
+ "Ġheadband": 293,
345
+ "ie": 294,
346
+ "ta": 295,
347
+ "Ġsmal": 296,
348
+ "Ġsmall": 297,
349
+ "pe": 298,
350
+ "Ġvr": 299,
351
+ "Ġ4": 300,
352
+ "hain": 301,
353
+ "Ġchain": 302,
354
+ "Ġpipe": 303,
355
+ "ak": 304,
356
+ "cho": 305,
357
+ "eak": 306,
358
+ "ike": 307,
359
+ "ncho": 308,
360
+ "toncho": 309,
361
+ "Ġpeak": 310,
362
+ "Ġmut": 311,
363
+ "Ġspike": 312,
364
+ "tonchops": 313,
365
+ "Ġmuttonchops": 314,
366
+ "ag": 315,
367
+ "rag": 316,
368
+ "Ġdo": 317,
369
+ "Ġgoat": 318,
370
+ "che": 319,
371
+ "Ġmus": 320,
372
+ "tache": 321,
373
+ "Ġmustache": 322,
374
+ "ur": 323,
375
+ "io": 324,
376
+ "xur": 325,
377
+ "Ġlu": 326,
378
+ "ious": 327,
379
+ "xurious": 328,
380
+ "Ġluxurious": 329,
381
+ "hin": 330,
382
+ "str": 331,
383
+ "Ġchin": 332,
384
+ "strap": 333,
385
+ "Ġchinstrap": 334,
386
+ "ape": 335,
387
+ "Ġvape": 336,
388
+ "bar": 337,
389
+ "ndle": 338,
390
+ "Ġhandle": 339,
391
+ "bars": 340,
392
+ "Ġhandlebars": 341,
393
+ "Ġfrown": 342,
394
+ "Ġhood": 343,
395
+ "Ġhoodie": 344,
396
+ "war": 345,
397
+ "Ġfor": 346,
398
+ "ward": 347,
399
+ "Ġforward": 348,
400
+ "il": 349,
401
+ "ile": 350,
402
+ "mile": 351,
403
+ "Ġsmile": 352,
404
+ "Ġno": 353,
405
+ "se": 354,
406
+ "Ġnose": 355,
407
+ "oli": 356,
408
+ "Ġpoli": 357,
409
+ "Ġpolice": 358,
410
+ "dor": 359,
411
+ "Ġfedor": 360,
412
+ "Ġfedora": 361,
413
+ "ass": 362,
414
+ "Ġtass": 363,
415
+ "Ġtassle": 364,
416
+ "al": 365,
417
+ "Ġmed": 366,
418
+ "ical": 367,
419
+ "Ġmedical": 368,
420
+ "Ġgold": 369,
421
+ "ver": 370,
422
+ "Ġsil": 371,
423
+ "Ġsilver": 372,
424
+ "amp": 373,
425
+ "ire": 374,
426
+ "lf": 375,
427
+ "ob": 376,
428
+ "Ġbob": 377,
429
+ "Ġhalf": 378,
430
+ "Ġvamp": 379,
431
+ "Ġred": 380,
432
+ "Ġvampire": 381,
433
+ "bo": 382,
434
+ "Ġcow": 383,
435
+ "boy": 384,
436
+ "Ġcowboy": 385,
437
+ "hi": 386,
438
+ "te": 387,
439
+ "Ġwhi": 388,
440
+ "Ġwhite": 389,
441
+ "rt": 390,
442
+ "Ġsho": 391,
443
+ "Ġshort": 392,
444
+ "ek": 393,
445
+ "Ġro": 394,
446
+ "Ġche": 395,
447
+ "eks": 396,
448
+ "Ġrosy": 397,
449
+ "Ġcheeks": 398,
450
+ "ot": 399,
451
+ "Ġspot": 400,
452
+ "Ġspots": 401,
453
+ "Ġto": 402,
454
+ "Ġtop": 403,
455
+ "Ġpink": 404,
456
+ "Ġpig": 405,
457
+ "tail": 406,
458
+ "Ġpigtail": 407,
459
+ "Ġpigtails": 408,
460
+ "Ġz": 409,
461
+ "bie": 410,
462
+ "mbie": 411,
463
+ "ombie": 412,
464
+ "Ġzombie": 413,
465
+ "eld": 414,
466
+ "gg": 415,
467
+ "les": 416,
468
+ "Ġweld": 417,
469
+ "Ġgogg": 418,
470
+ "Ġwelding": 419,
471
+ "Ġgoggles": 420,
472
+ "ee": 421,
473
+ "uck": 422,
474
+ "Ġbuck": 423,
475
+ "Ġtee": 424,
476
+ "Ġteeth": 425,
477
+ "Ġ1": 426,
478
+ "ge": 427,
479
+ "ide": 428,
480
+ "ran": 429,
481
+ "Ġside": 430,
482
+ "Ġoran": 431,
483
+ "Ġorange": 432,
484
+ "Ġattribute": 433,
485
+ "iar": 434,
486
+ "Ġtiar": 435,
487
+ "Ġtiara": 436,
488
+ "et": 437,
489
+ "lm": 438,
490
+ "lot": 439,
491
+ "Ġpilot": 440,
492
+ "Ġhelm": 441,
493
+ "Ġhelmet": 442,
494
+ "Ġcho": 443,
495
+ "ker": 444,
496
+ "Ġchoker": 445,
497
+ "ean": 446,
498
+ "Ġbean": 447,
499
+ "Ġbeanie": 448,
500
+ "Ġ5": 449,
501
+ "Ġape": 450,
502
+ "Ġali": 451,
503
+ "Ġalien": 452,
504
+ "Ġ6": 453,
505
+ "imple": 454,
506
+ "Ġ0": 455,
507
+ "Ġsimple": 456,
508
+ "Ġfeat": 457,
509
+ "ures": 458,
510
+ "Ġfeatures": 459,
511
+ "ace": 460,
512
+ "ase": 461,
513
+ "ero": 462,
514
+ "lin": 463,
515
+ "simple": 464,
516
+ "Ġaver": 465,
517
+ "Ġbut": 466,
518
+ "Ġbare": 467,
519
+ "Ġbase": 468,
520
+ "thing": 469,
521
+ "Ġface": 470,
522
+ "age": 471,
523
+ "Ġnothing": 472,
524
+ "Ġzero": 473,
525
+ "line": 474,
526
+ "Ġaverage": 475,
527
+ "Ġ7": 476
528
+ },
529
+ "merges": [
530
+ "Ġ a",
531
+ "n d",
532
+ "Ġ b",
533
+ "h a",
534
+ "l e",
535
+ "m a",
536
+ "Ġ c",
537
+ "r o",
538
+ "p u",
539
+ "c k",
540
+ "t o",
541
+ "Ġa nd",
542
+ "a ck",
543
+ "a r",
544
+ "Ġ ma",
545
+ "n k",
546
+ "g ro",
547
+ "u nd",
548
+ "Ġb ack",
549
+ "pu nk",
550
+ "gro und",
551
+ "Ġback ground",
552
+ "Ġ ha",
553
+ "Ġc r",
554
+ "i n",
555
+ "Ġ s",
556
+ "Ġ o",
557
+ "p to",
558
+ "y pto",
559
+ "Ġcr ypto",
560
+ "Ġcrypto punk",
561
+ "Ġo f",
562
+ "e s",
563
+ "Ġ w",
564
+ "Ġw i",
565
+ "t h",
566
+ "in g",
567
+ "h o",
568
+ "l o",
569
+ "Ġwi th",
570
+ "Ġ p",
571
+ "Ġma le",
572
+ "Ġ g",
573
+ "Ġ f",
574
+ "e ar",
575
+ "Ġha s",
576
+ "Ġ m",
577
+ "l u",
578
+ "Ġ t",
579
+ "r e",
580
+ "Ġf e",
581
+ "d e",
582
+ "w n",
583
+ "o k",
584
+ "ho to",
585
+ "lo ok",
586
+ "Ġp hoto",
587
+ "ha t",
588
+ "Ġt hat",
589
+ "Ġma de",
590
+ "ma le",
591
+ "Ġfe male",
592
+ "t r",
593
+ "i r",
594
+ "Ġha ir",
595
+ "Ġs ha",
596
+ "p le",
597
+ "r ple",
598
+ "Ġ pu",
599
+ "Ġpu rple",
600
+ "u t",
601
+ "e n",
602
+ "Ġg re",
603
+ "Ġgre en",
604
+ "Ġsha d",
605
+ "Ġb lu",
606
+ "Ġblu e",
607
+ "Ġm o",
608
+ "look ing",
609
+ "l i",
610
+ "Ġ r",
611
+ "ro wn",
612
+ "t i",
613
+ "Ġ li",
614
+ "l a",
615
+ "a p",
616
+ "Ġb ear",
617
+ "Ġbear d",
618
+ "ar e",
619
+ "Ġb rown",
620
+ "w k",
621
+ "ha wk",
622
+ "Ġmo hawk",
623
+ "i g",
624
+ "r ing",
625
+ "Ġ ear",
626
+ "Ġear ring",
627
+ "e d",
628
+ "e y",
629
+ "Ġ ey",
630
+ "b ut",
631
+ "i but",
632
+ "t tr",
633
+ "Ġa ttr",
634
+ "punk y",
635
+ "Ġattr ibut",
636
+ "p s",
637
+ "Ġattribut es",
638
+ "Ġc ap",
639
+ "s s",
640
+ "ti ck",
641
+ "Ġli ps",
642
+ "Ġlips tick",
643
+ "Ġshad es",
644
+ "la ss",
645
+ "Ġ n",
646
+ "Ġey e",
647
+ "Ġ punky",
648
+ "Ġr are",
649
+ "Ġ ho",
650
+ "o w",
651
+ "Ġg lass",
652
+ "Ġglass es",
653
+ "t t",
654
+ "Ġshad ow",
655
+ "Ġ 3",
656
+ "Ġg o",
657
+ "o r",
658
+ "Ġ d",
659
+ "ma l",
660
+ "Ġp i",
661
+ "Ġs tr",
662
+ "h e",
663
+ "Ġc lo",
664
+ "Ġclo wn",
665
+ "k e",
666
+ "o d",
667
+ "ar k",
668
+ "Ġd ark",
669
+ "l d",
670
+ "c e",
671
+ "Ġc ig",
672
+ "are tt",
673
+ "Ġcig arett",
674
+ "Ġcigarett e",
675
+ "l ack",
676
+ "Ġb lack",
677
+ "a nd",
678
+ "Ġn or",
679
+ "Ġnor mal",
680
+ "Ġ 2",
681
+ "n t",
682
+ "ro nt",
683
+ "Ġf ront",
684
+ "Ġ looking",
685
+ "c ar",
686
+ "c ut",
687
+ "e la",
688
+ "o n",
689
+ "o lu",
690
+ "t ed",
691
+ "u p",
692
+ "x ela",
693
+ "Ġ lo",
694
+ "Ġ look",
695
+ "Ġ up",
696
+ "Ġs ing",
697
+ "Ġs car",
698
+ "es olu",
699
+ "ho w",
700
+ "Ġr esolu",
701
+ "ti on",
702
+ "Ġli ke",
703
+ "Ġgo od",
704
+ "Ġpi xela",
705
+ "cut e",
706
+ "Ġlo w",
707
+ "Ġsing le",
708
+ "Ġscar ce",
709
+ "Ġresolu tion",
710
+ "Ġpixela ted",
711
+ "f u",
712
+ "n n",
713
+ "fu nn",
714
+ "funn y",
715
+ "Ġey es",
716
+ "Ġ he",
717
+ "a t",
718
+ "Ġ v",
719
+ "a ig",
720
+ "h t",
721
+ "Ġstr aig",
722
+ "Ġstraig ht",
723
+ "e r",
724
+ "Ġwi ld",
725
+ "a d",
726
+ "Ġhe ad",
727
+ "Ġho t",
728
+ "Ġb ig",
729
+ "i c",
730
+ "Ġ re",
731
+ "Ġmo le",
732
+ "a n",
733
+ "m p",
734
+ "s y",
735
+ "u s",
736
+ "Ġn er",
737
+ "Ġner d",
738
+ "nd e",
739
+ "Ġb lo",
740
+ "Ġblo nde",
741
+ "i m",
742
+ "n ed",
743
+ "r ned",
744
+ "Ġr im",
745
+ "Ġho rned",
746
+ "Ġha t",
747
+ "g u",
748
+ "l ar",
749
+ "Ġre gu",
750
+ "Ġregu lar",
751
+ "Ġc lass",
752
+ "Ġclass ic",
753
+ "Ġb and",
754
+ "an a",
755
+ "Ġband ana",
756
+ "s k",
757
+ "Ġma sk",
758
+ "ing y",
759
+ "Ġstr ingy",
760
+ "c h",
761
+ "Ġp at",
762
+ "Ġpat ch",
763
+ "es sy",
764
+ "Ġm essy",
765
+ "v ed",
766
+ "Ġsha ved",
767
+ "r u",
768
+ "Ġf ru",
769
+ "mp y",
770
+ "Ġfru mpy",
771
+ "Ġ th",
772
+ "Ġth in",
773
+ "Ġs p",
774
+ "i tt",
775
+ "k n",
776
+ "Ġ kn",
777
+ "itt ed",
778
+ "Ġkn itted",
779
+ "a z",
780
+ "Ġcr az",
781
+ "Ġcraz y",
782
+ "b and",
783
+ "Ġhead band",
784
+ "i e",
785
+ "t a",
786
+ "Ġs mal",
787
+ "Ġsmal l",
788
+ "p e",
789
+ "Ġv r",
790
+ "Ġ 4",
791
+ "ha in",
792
+ "Ġc hain",
793
+ "Ġpi pe",
794
+ "a k",
795
+ "c ho",
796
+ "e ak",
797
+ "i ke",
798
+ "n cho",
799
+ "to ncho",
800
+ "Ġp eak",
801
+ "Ġm ut",
802
+ "Ġsp ike",
803
+ "toncho ps",
804
+ "Ġmut tonchops",
805
+ "a g",
806
+ "r ag",
807
+ "Ġd o",
808
+ "Ġgo at",
809
+ "c he",
810
+ "Ġm us",
811
+ "ta che",
812
+ "Ġmus tache",
813
+ "u r",
814
+ "i o",
815
+ "x ur",
816
+ "Ġ lu",
817
+ "io us",
818
+ "xur ious",
819
+ "Ġlu xurious",
820
+ "h in",
821
+ "s tr",
822
+ "Ġc hin",
823
+ "str ap",
824
+ "Ġchin strap",
825
+ "ap e",
826
+ "Ġv ape",
827
+ "b ar",
828
+ "nd le",
829
+ "Ġha ndle",
830
+ "bar s",
831
+ "Ġhandle bars",
832
+ "Ġf rown",
833
+ "Ġho od",
834
+ "Ġhood ie",
835
+ "w ar",
836
+ "Ġf or",
837
+ "war d",
838
+ "Ġfor ward",
839
+ "i l",
840
+ "i le",
841
+ "m ile",
842
+ "Ġs mile",
843
+ "Ġn o",
844
+ "s e",
845
+ "Ġno se",
846
+ "o li",
847
+ "Ġp oli",
848
+ "Ġpoli ce",
849
+ "d or",
850
+ "Ġfe dor",
851
+ "Ġfedor a",
852
+ "a ss",
853
+ "Ġt ass",
854
+ "Ġtass le",
855
+ "a l",
856
+ "Ġm ed",
857
+ "ic al",
858
+ "Ġmed ical",
859
+ "Ġgo ld",
860
+ "v er",
861
+ "Ġs il",
862
+ "Ġsil ver",
863
+ "a mp",
864
+ "i re",
865
+ "l f",
866
+ "o b",
867
+ "Ġb ob",
868
+ "Ġha lf",
869
+ "Ġv amp",
870
+ "Ġre d",
871
+ "Ġvamp ire",
872
+ "b o",
873
+ "Ġc ow",
874
+ "bo y",
875
+ "Ġcow boy",
876
+ "h i",
877
+ "t e",
878
+ "Ġw hi",
879
+ "Ġwhi te",
880
+ "r t",
881
+ "Ġs ho",
882
+ "Ġsho rt",
883
+ "e k",
884
+ "Ġ ro",
885
+ "Ġc he",
886
+ "ek s",
887
+ "Ġro sy",
888
+ "Ġche eks",
889
+ "o t",
890
+ "Ġsp ot",
891
+ "Ġspot s",
892
+ "Ġ to",
893
+ "Ġto p",
894
+ "Ġpi nk",
895
+ "Ġp ig",
896
+ "ta il",
897
+ "Ġpig tail",
898
+ "Ġpigtail s",
899
+ "Ġ z",
900
+ "b ie",
901
+ "m bie",
902
+ "o mbie",
903
+ "Ġz ombie",
904
+ "e ld",
905
+ "g g",
906
+ "le s",
907
+ "Ġw eld",
908
+ "Ġgo gg",
909
+ "Ġweld ing",
910
+ "Ġgogg les",
911
+ "e e",
912
+ "u ck",
913
+ "Ġb uck",
914
+ "Ġt ee",
915
+ "Ġtee th",
916
+ "Ġ 1",
917
+ "g e",
918
+ "i de",
919
+ "r an",
920
+ "Ġs ide",
921
+ "Ġo ran",
922
+ "Ġoran ge",
923
+ "Ġattribut e",
924
+ "i ar",
925
+ "Ġt iar",
926
+ "Ġtiar a",
927
+ "e t",
928
+ "l m",
929
+ "lo t",
930
+ "Ġpi lot",
931
+ "Ġhe lm",
932
+ "Ġhelm et",
933
+ "Ġc ho",
934
+ "ke r",
935
+ "Ġcho ker",
936
+ "e an",
937
+ "Ġb ean",
938
+ "Ġbean ie",
939
+ "Ġ 5",
940
+ "Ġa pe",
941
+ "Ġa li",
942
+ "Ġali en",
943
+ "Ġ 6",
944
+ "im ple",
945
+ "Ġ 0",
946
+ "Ġs imple",
947
+ "Ġfe at",
948
+ "ur es",
949
+ "Ġfeat ures",
950
+ "a ce",
951
+ "a se",
952
+ "e ro",
953
+ "l in",
954
+ "s imple",
955
+ "Ġa ver",
956
+ "Ġb ut",
957
+ "Ġb are",
958
+ "Ġb ase",
959
+ "th ing",
960
+ "Ġf ace",
961
+ "ag e",
962
+ "Ġno thing",
963
+ "Ġz ero",
964
+ "lin e",
965
+ "Ġaver age",
966
+ "Ġ 7"
967
+ ]
968
+ }
969
+ }
text2punks/data/codebook.pt ADDED
Binary file (1.39 kB)
 
text2punks/loader.py ADDED
@@ -0,0 +1,96 @@
1
+ import numpy as np
2
+ from PIL import Image, UnidentifiedImageError
3
+
4
+ from pathlib import Path
5
+ from random import randint, choice
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class TextImageDataset(Dataset):
12
+ def __init__(self,
13
+ folder,
14
+ text_len=40,
15
+ truncate_captions=False,
16
+ text_tokenizer=None,
17
+ image_tokenizer=None,
18
+ shuffle=False
19
+ ):
20
+ """
21
+ @param folder: Folder containing images and text files matched by their paths' respective "stem"
22
+ @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated.
23
+ """
24
+ super().__init__()
25
+ self.shuffle = shuffle
26
+ path = Path(folder)
27
+
28
+ text_files = [*path.glob('**/*.txt')]
29
+ image_files = [
30
+ *path.glob('**/*.png'), *path.glob('**/*.jpg'),
31
+ *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
32
+ ]
33
+
34
+ text_files = {text_file.stem: text_file for text_file in text_files}
35
+ image_files = {image_file.stem: image_file for image_file in image_files}
36
+
37
+ keys = (image_files.keys() & text_files.keys())
38
+
39
+ self.keys = list(keys)
40
+ self.text_files = {k: v for k, v in text_files.items() if k in keys}
41
+ self.image_files = {k: v for k, v in image_files.items() if k in keys}
42
+ self.text_len = text_len
43
+ self.truncate_captions = truncate_captions
44
+ self.text_tokenizer = text_tokenizer
45
+ self.image_tokenizer = image_tokenizer
46
+
47
+
48
+ def __len__(self):
49
+ return len(self.keys)
50
+
51
+ def random_sample(self):
52
+ return self.__getitem__(randint(0, self.__len__() - 1))
53
+
54
+ def sequential_sample(self, ind):
55
+ if ind >= self.__len__() - 1:
56
+ return self.__getitem__(0)
57
+ return self.__getitem__(ind + 1)
58
+
59
+ def skip_sample(self, ind):
60
+ if self.shuffle:
61
+ return self.random_sample()
62
+ return self.sequential_sample(ind=ind)
63
+
64
+ def __getitem__(self, ind):
65
+ key = self.keys[ind]
66
+
67
+ text_file = self.text_files[key]
68
+ image_file = self.image_files[key]
69
+
70
+ descriptions = text_file.read_text().split('\n')
71
+ descriptions = list(filter(lambda t: len(t) > 0, descriptions))
72
+ try:
73
+ description = choice(descriptions)
74
+ except IndexError as zero_captions_in_file_ex:
75
+ print(f"An exception occurred trying to load file {text_file}.")
76
+ print(f"Skipping index {ind}")
77
+ return self.skip_sample(ind)
78
+
79
+ tokenized_text = self.text_tokenizer.tokenize(
80
+ description,
81
+ self.text_len,
82
+ truncate_text=self.truncate_captions
83
+ ).squeeze(0)
84
+ try:
85
+ image = Image.open(image_file).convert('RGB')
86
+ pixels = np.array(image).reshape(-1, 3)
87
+
88
+ tokenized_image = [self.image_tokenizer[str(idx)] for idx in pixels]
89
+ tokenized_image = torch.tensor(tokenized_image)
90
+ except (UnidentifiedImageError, OSError) as corrupt_image_exceptions:
91
+ print(f"An exception occurred trying to load file {image_file}.")
92
+ print(f"Skipping index {ind}")
93
+ return self.skip_sample(ind)
94
+
95
+ # Success
96
+ return tokenized_text, tokenized_image
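
A sketch of how the dataset is meant to be wired up; the folder path and the pixel-to-token mapping below are placeholders (the real mapping pairs each distinct RGB pixel value, keyed by str(pixel), with a token id):

# illustrative only: placeholder folder and pixel-to-token lookup
from text2punks.loader import TextImageDataset
from text2punks.tokenizer import txt_tokenizer

dataset = TextImageDataset(
    folder='data/punks',                   # caption .txt and image .png matched by stem
    text_len=40,
    truncate_captions=True,
    text_tokenizer=txt_tokenizer,
    image_tokenizer={'[255 255 255]': 0},  # placeholder str(RGB pixel) -> token id map
    shuffle=True,
)
text_tokens, image_tokens = dataset[0]     # (40,) and (576,) LongTensors
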
text2punks/text2punk.py ADDED
@@ -0,0 +1,377 @@
1
+ import math
2
+
3
+ from einops import rearrange, repeat
4
+
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+
9
+ from axial_positional_embedding import AxialPositionalEmbedding
10
+ from text2punks.transformer import Transformer
11
+
12
+
13
+ # helpers fns
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+ def default(val, d):
19
+ return val if exists(val) else d
20
+
21
+ def set_requires_grad(model, value):
22
+ for param in model.parameters():
23
+ param.requires_grad = value
24
+
25
+ def eval_decorator(fn):
26
+ def inner(model, *args, **kwargs):
27
+ was_training = model.training
28
+ model.eval()
29
+ out = fn(model, *args, **kwargs)
30
+ model.train(was_training)
31
+ return out
32
+ return inner
33
+
34
+ # sampling helpers fn
35
+
36
+ def top_k(logits, thres = 0.5):
37
+ num_logits = logits.shape[-1]
38
+ k = max(int((1 - thres) * num_logits), 1)
39
+ val, ind = torch.topk(logits, k)
40
+ probs = torch.full_like(logits, float('-inf'))
41
+ probs.scatter_(1, ind, val)
42
+ return probs
43
+
44
+ # main CLIP class
45
+
46
+ class CLIP(nn.Module):
47
+ def __init__(
48
+ self,
49
+ *,
50
+ dim_text = 512,
51
+ dim_image = 512,
52
+ dim_latent = 512,
53
+ num_text_tokens = 10000,
54
+ text_enc_depth = 6,
55
+ text_seq_len = 256,
56
+ text_heads = 8,
57
+ num_visual_tokens = 256,
58
+ visual_enc_depth = 6,
59
+ visual_image_seq_len = 256,
60
+ visual_image_size = 24,
61
+ visual_heads = 8,
62
+ attn_pdrop = 0.1,
63
+ resid_pdrop = 0.1,
64
+ embd_pdrop = 0.1,
65
+ ff_dropout = 0.1,
66
+ attn_types = None
67
+ ):
68
+ super().__init__()
69
+
70
+ # Texts
71
+
72
+ self.text_emb = nn.Embedding(num_text_tokens, dim_text)
73
+ self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
74
+
75
+ self.text_transformer = Transformer(
76
+ dim = dim_text,
77
+ causal = False,
78
+ seq_len = text_seq_len,
79
+ depth = text_enc_depth,
80
+ heads = text_heads,
81
+ dim_head = dim_text // text_heads,
82
+ attn_dropout = attn_pdrop,
83
+ resid_dropout = resid_pdrop,
84
+ embd_dropout = embd_pdrop,
85
+ ff_dropout = ff_dropout,
86
+ attn_types = attn_types
87
+ )
88
+
89
+ self.text_ln = nn.LayerNorm(dim_text)
90
+ self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
91
+
92
+ # Images
93
+
94
+ self.image_emb = nn.Embedding(num_visual_tokens, dim_image)
95
+ self.image_pos_emb = nn.Embedding(visual_image_seq_len, dim_image)
96
+
97
+ self.visual_transformer = Transformer(
98
+ dim = dim_image,
99
+ causal = False,
100
+ seq_len = visual_image_seq_len,
101
+ depth = visual_enc_depth,
102
+ heads = visual_heads,
103
+ dim_head = dim_image // visual_heads,
104
+ attn_dropout = attn_pdrop,
105
+ resid_dropout = resid_pdrop,
106
+ embd_dropout = embd_pdrop,
107
+ ff_dropout = ff_dropout,
108
+ attn_types = attn_types,
109
+ image_size = visual_image_size,
110
+ )
111
+
112
+ self.image_ln = nn.LayerNorm(dim_image)
113
+ self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
114
+
115
+ self.temperature = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
116
+
117
+
118
+ self.apply(self._init_weights)
119
+
120
+ def _init_weights(self, module):
121
+ if isinstance(module, (nn.Linear, nn.Embedding)):
122
+ module.weight.data.normal_(mean=0.0, std=0.02)
123
+ if isinstance(module, nn.Linear) and module.bias is not None:
124
+ module.bias.data.zero_()
125
+ elif isinstance(module, nn.LayerNorm):
126
+ module.bias.data.zero_()
127
+ module.weight.data.fill_(1.0)
128
+
129
+ def forward(
130
+ self,
131
+ text,
132
+ image,
133
+ return_loss = False
134
+ ):
135
+ b, device= text.shape[0], text.device
136
+
137
+ text_emb = self.text_emb(text)
138
+ text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
139
+
140
+ image_emb = self.image_emb(image)
141
+ image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))
142
+
143
+ enc_text = self.text_transformer(text_emb)
144
+ enc_image = self.visual_transformer(image_emb)
145
+
146
+ text_latents = enc_text.mean(dim = 1)
147
+ image_latents = enc_image.mean(dim = 1)
148
+
149
+ text_latents = self.text_ln(text_latents)
150
+ image_latents = self.image_ln(image_latents)
151
+
152
+ text_latents = self.to_text_latent(text_latents)
153
+ image_latents = self.to_visual_latent(image_latents)
154
+
155
+ text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))
156
+
157
+ temp = self.temperature.exp()
158
+
159
+ if not return_loss:
160
+ sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
161
+ return sim
162
+
163
+ sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
164
+ labels = torch.arange(b, device = device)
165
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
166
+ return loss
167
+
168
+ # main Text2Punks class
169
+
170
+ class Text2Punks(nn.Module):
171
+ def __init__(
172
+ self,
173
+ *,
174
+ n_embd,
175
+ n_layer = 12,
176
+ n_head = 12,
177
+ d_head = 64,
178
+ num_text_tokens = 10000,
179
+ text_seq_len = 256,
180
+ num_image_tokens = 222,
181
+ image_seq_len = 576,
182
+ image_size = 24,
183
+ attn_pdrop = 0.1,
184
+ resid_pdrop = 0.1,
185
+ embd_pdrop = 0.1,
186
+ ff_dropout = 0.1,
187
+ attn_types = None,
188
+ loss_img_weight = 7,
189
+ loss_txt_weight = 7,
190
+ ):
191
+ super().__init__()
192
+
193
+ num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
194
+
195
+ self.text_emb = nn.Embedding(num_text_tokens, n_embd)
196
+ self.image_emb = nn.Embedding(num_image_tokens, n_embd)
197
+
198
+ self.text_pos_emb = nn.Embedding(text_seq_len + 1, n_embd) # +1 for <bos> a.k.a <sos>
199
+ # self.image_pos_emb = nn.Embedding(image_seq_len, n_embd)
200
+ self.image_pos_emb = nn.Parameter(torch.zeros(1, image_seq_len, n_embd))
201
+ # self.image_pos_emb = AxialPositionalEmbedding(n_embd, axial_shape=(image_size, image_size))
202
+
203
+ self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
204
+ self.num_image_tokens = num_image_tokens
205
+ self.text_seq_len = text_seq_len
206
+ self.image_seq_len = image_seq_len
207
+
208
+ seq_len = text_seq_len + image_seq_len
209
+ total_tokens = num_text_tokens + num_image_tokens
210
+ self.total_seq_len = seq_len
211
+ self.total_tokens = total_tokens
212
+
213
+ self.transformer = Transformer(
214
+ dim = n_embd,
215
+ causal = True,
216
+ seq_len = seq_len,
217
+ depth = n_layer,
218
+ heads = n_head,
219
+ dim_head = d_head,
220
+ attn_dropout = attn_pdrop,
221
+ resid_dropout = resid_pdrop,
222
+ embd_dropout = embd_pdrop,
223
+ ff_dropout = ff_dropout,
224
+ attn_types = attn_types,
225
+ image_size = image_size,
226
+ )
227
+
228
+ self.to_logits = nn.Sequential(
229
+ nn.LayerNorm(n_embd),
230
+ nn.Linear(n_embd, self.total_tokens),
231
+ )
232
+
233
+ seq_range = torch.arange(seq_len)
234
+ logits_range = torch.arange(total_tokens)
235
+
236
+ seq_range = rearrange(seq_range, 'n -> () n ()')
237
+ logits_range = rearrange(logits_range, 'd -> () () d')
238
+
239
+ logits_mask = (
240
+ ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
241
+ ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
242
+ )
243
+
244
+ self.register_buffer('logits_mask', logits_mask, persistent=False)
245
+ self.loss_img_weight = loss_img_weight
246
+ self.loss_txt_weight = loss_txt_weight
247
+
248
+ self.apply(self._init_weights)
249
+
250
+ def _init_weights(self, module):
251
+ if isinstance(module, (nn.Linear, nn.Embedding)):
252
+ module.weight.data.normal_(mean=0.0, std=0.02)
253
+ if isinstance(module, nn.Linear) and module.bias is not None:
254
+ module.bias.data.zero_()
255
+ elif isinstance(module, nn.LayerNorm):
256
+ module.bias.data.zero_()
257
+ module.weight.data.fill_(1.0)
258
+
259
+ @torch.no_grad()
260
+ @eval_decorator
261
+ def generate_images(
262
+ self,
263
+ text,
264
+ decoder,
265
+ *,
266
+ clip = None,
267
+ filter_thres = 0.5,
268
+ temperature = 1.,
269
+ img = None,
270
+ num_init_img_tokens = None
271
+ ):
272
+ text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
273
+ total_len = text_seq_len + image_seq_len
274
+
275
+ batch = text.shape[0]
276
+ text = text[:, :text_seq_len] # make sure text is within bounds
277
+ out = text
278
+
279
+ if exists(img):
280
+ assert img.shape[1] == image_seq_len, f'input image must have the correct image size {image_seq_len}'
281
+
282
+ num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len)) # OpenAI used 14 * 32 initial tokens to prime
283
+ assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'
284
+
285
+ trunc_img = img[:, :num_img_tokens]
286
+ out = torch.cat((out, trunc_img), dim = -1)
287
+
288
+ for cur_len in range(out.shape[1], total_len):
289
+ is_image = cur_len >= text_seq_len
290
+
291
+ text, image = out[:, :text_seq_len], out[:, text_seq_len:]
292
+
293
+ logits = self(text, image)[:, -1, :]
294
+
295
+ filtered_logits = top_k(logits, thres = filter_thres)
296
+ probs = F.softmax(filtered_logits / temperature, dim = -1)
297
+ sample = torch.multinomial(probs, 1)
298
+
299
+ sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
300
+ out = torch.cat((out, sample), dim=-1)
301
+
302
+ text_seq = out[:, :text_seq_len]
303
+ img_seq = out[:, -image_seq_len:]
304
+
305
+ scores = None
306
+ if exists(clip):
307
+ scores = clip(text_seq, img_seq, return_loss = False)
308
+
309
+ img_seq = repeat(img_seq, 'b p -> b p c', c=3)
310
+ decoder = repeat(decoder, 'p c -> b p c', b=batch)
311
+ images = torch.gather(decoder, 1, img_seq)
312
+ images = rearrange(images, 'b (h w) c -> b c h w', h=24, w=24)
313
+ images = images.float()
314
+
315
+ return images, scores
316
+
317
+ def forward(
318
+ self,
319
+ text,
320
+ image = None,
321
+ return_loss = False
322
+ ):
323
+ assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
324
+ device, total_seq_len = text.device, self.total_seq_len
325
+
326
+ text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
327
+ text = torch.where(text == 0, text_range, text)
328
+
329
+ text = F.pad(text, (1, 0), value = 0) # add <bos>
330
+
331
+ tokens = self.text_emb(text)
332
+ tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))
333
+
334
+ seq_len = tokens.shape[1]
335
+
336
+ image_len = image.shape[1]
337
+ image_emb = self.image_emb(image)
338
+ # image_emb += self.image_pos_emb(torch.arange(image_len, device = device))
339
+ image_emb += self.image_pos_emb[:, :image_len, :]
340
+
341
+ # image_emb += self.image_pos_emb(image_emb)
342
+
343
+ tokens = torch.cat((tokens, image_emb), dim = 1)
344
+
345
+ seq_len += image_len
346
+
347
+ # when training, if the length exceeds the total text + image length
348
+ # remove the last token, since it needs not to be trained
349
+
350
+ if tokens.shape[1] > total_seq_len:
351
+ seq_len -= 1
352
+ tokens = tokens[:, :-1]
353
+
354
+ out = self.transformer(tokens)
355
+ logits = self.to_logits(out)
356
+
357
+ # mask logits to make sure text predicts text (except last token), and image predicts image
358
+
359
+ logits_mask = self.logits_mask[:, :seq_len]
360
+ max_neg_value = -torch.finfo(logits.dtype).max
361
+ logits.masked_fill_(logits_mask, max_neg_value)
362
+
363
+ if not return_loss:
364
+ return logits
365
+
366
+ assert exists(image), 'when training, image must be supplied'
367
+
368
+ offsetted_image = image + self.num_text_tokens
369
+ labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)
370
+
371
+ logits = rearrange(logits, 'b n c -> b c n')
372
+
373
+ loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
374
+ loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])
375
+
376
+ loss = (self.loss_txt_weight * loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + self.loss_txt_weight)
377
+ return loss, loss_text, loss_img
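
A sketch of a training-style forward pass with dummy tokens; the hyperparameters are illustrative, not the released configuration:

# illustrative forward/backward pass with random token ids
import torch
from text2punks.text2punk import Text2Punks

model = Text2Punks(n_embd=64, n_layer=2, n_head=2, d_head=32,
                   num_text_tokens=480, text_seq_len=40,
                   num_image_tokens=222, image_seq_len=576, image_size=24)
text = torch.randint(1, 480, (2, 40))    # (batch, text_seq_len)
image = torch.randint(0, 222, (2, 576))  # (batch, 24 * 24 image tokens)
loss, loss_text, loss_img = model(text, image, return_loss=True)
loss.backward()
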
text2punks/tokenizer.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+ import html
3
+ import ftfy
4
+ import regex as re
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ from functools import lru_cache
10
+
11
+ import youtokentome as yttm
12
+ from tokenizers import Tokenizer
13
+ from tokenizers.processors import ByteLevel
14
+
15
+
16
+ # OpenAI simple tokenizer
17
+
18
+ @lru_cache()
19
+ def default_bpe(bpe_path = "data/bpe_simple_vocab_16e6.txt"):
20
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), bpe_path)
21
+
22
+ @lru_cache()
23
+ def bytes_to_unicode():
24
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
25
+ cs = bs[:]
26
+ n = 0
27
+ for b in range(2 ** 8):
28
+ if b not in bs:
29
+ bs.append(b)
30
+ cs.append(2 ** 8 + n)
31
+ n += 1
32
+ cs = [chr(n) for n in cs]
33
+ return dict(zip(bs, cs))
34
+
35
+ def get_pairs(word):
36
+ pairs = set()
37
+ prev_char = word[0]
38
+ for char in word[1:]:
39
+ pairs.add((prev_char, char))
40
+ prev_char = char
41
+ return pairs
42
+
43
+ def basic_clean(text):
44
+ text = ftfy.fix_text(text)
45
+ text = html.unescape(html.unescape(text))
46
+ return text.strip()
47
+
48
+ def whitespace_clean(text):
49
+ text = re.sub(r'\s+', ' ', text)
50
+ text = text.strip()
51
+ return text
52
+
53
+
54
+ class SimpleTokenizer(object):
55
+ def __init__(self, bpe_path = default_bpe()):
56
+ self.byte_encoder = bytes_to_unicode()
57
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
58
+ merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
59
+ merges = merges[1:49152 - 256 - 2 + 1]
60
+ merges = [tuple(merge.split()) for merge in merges]
61
+ vocab = list(bytes_to_unicode().values())
62
+ vocab = vocab + [v + '</w>' for v in vocab]
63
+ for merge in merges:
64
+ vocab.append(''.join(merge))
65
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
66
+
67
+ self.vocab_size = 49408
68
+
69
+ self.encoder = dict(zip(vocab, range(len(vocab))))
70
+ self.decoder = {v: k for k, v in self.encoder.items()}
71
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
72
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
73
+ self.pat = re.compile(
74
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
75
+ re.IGNORECASE)
76
+
77
+ def bpe(self, token):
78
+ if token in self.cache:
79
+ return self.cache[token]
80
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
81
+ pairs = get_pairs(word)
82
+
83
+ if not pairs:
84
+ return token + '</w>'
85
+
86
+ while True:
87
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
88
+ if bigram not in self.bpe_ranks:
89
+ break
90
+ first, second = bigram
91
+ new_word = []
92
+ i = 0
93
+ while i < len(word):
94
+ try:
95
+ j = word.index(first, i)
96
+ new_word.extend(word[i:j])
97
+ i = j
98
+ except:
99
+ new_word.extend(word[i:])
100
+ break
101
+
102
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
103
+ new_word.append(first + second)
104
+ i += 2
105
+ else:
106
+ new_word.append(word[i])
107
+ i += 1
108
+ new_word = tuple(new_word)
109
+ word = new_word
110
+ if len(word) == 1:
111
+ break
112
+ else:
113
+ pairs = get_pairs(word)
114
+ word = ' '.join(word)
115
+ self.cache[token] = word
116
+ return word
117
+
118
+ def encode(self, text):
119
+ bpe_tokens = []
120
+ text = whitespace_clean(basic_clean(text)).lower()
121
+ for token in re.findall(self.pat, text):
122
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
123
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
124
+ return bpe_tokens
125
+
126
+ def decode(self, tokens, remove_start_end = True):
127
+ if torch.is_tensor(tokens):
128
+ tokens = tokens.tolist()
129
+
130
+ if remove_start_end:
131
+ tokens = [token for token in tokens if token not in (49406, 49407, 0)]  # drop <|startoftext|>, <|endoftext|>, and padding ids
132
+ text = ''.join([self.decoder[token] for token in tokens])
133
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134
+ return text
135
+
136
+ def tokenize(self, texts, context_length = 256, truncate_text = False):
137
+ if isinstance(texts, str):
138
+ texts = [texts]
139
+
140
+ all_tokens = [self.encode(text) for text in texts]
141
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
142
+
143
+ for i, tokens in enumerate(all_tokens):
144
+ if len(tokens) > context_length:
145
+ if truncate_text:
146
+ tokens = tokens[:context_length]
147
+ else:
148
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
149
+ result[i, :len(tokens)] = torch.tensor(tokens)
150
+
151
+ return result
152
+
153
+ # txt_tokenizer = SimpleTokenizer()
154
+
155
+ # huggingface tokenizer
156
+
157
+ class HugTokenizer:
158
+ def __init__(self, bpe_path):
159
+ bpe_path = Path(default_bpe(bpe_path = bpe_path))
160
+ assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
161
+ tokenizer = Tokenizer.from_file(str(bpe_path))
162
+ tokenizer.post_processor = ByteLevel(trim_offsets = True)
163
+ self.tokenizer = tokenizer
164
+ self.vocab_size = tokenizer.get_vocab_size()
165
+
166
+ def decode(self, tokens):
167
+ if torch.is_tensor(tokens):
168
+ tokens = tokens.tolist()
169
+
170
+ tokens = [token for token in tokens if token not in (0,)]
171
+ return self.tokenizer.decode(tokens, skip_special_tokens = True)
172
+
173
+ def encode(self, text):
174
+ return self.tokenizer.encode(text).ids
175
+
176
+ def tokenize(self, texts, context_length = 256, truncate_text = False):
177
+ if isinstance(texts, str):
178
+ texts = [texts]
179
+
180
+ all_tokens = [self.encode(text) for text in texts]
181
+
182
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
183
+ for i, tokens in enumerate(all_tokens):
184
+ if len(tokens) > context_length:
185
+ if truncate_text:
186
+ tokens = tokens[:context_length]
187
+ else:
188
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
189
+ result[i, :len(tokens)] = torch.tensor(tokens)
190
+
191
+ return result
192
+
193
+ txt_tokenizer = HugTokenizer(bpe_path = "data/byte-level-bpe_4k.tokenizer.json")
194
+
195
+ # yttm tokenizer
196
+
197
+ class YttmTokenizer:
198
+ def __init__(self, bpe_path = None):
199
+ bpe_path = Path(default_bpe(bpe_path = bpe_path))
200
+ assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
201
+
202
+ tokenizer = yttm.BPE(model = str(bpe_path))
203
+ self.tokenizer = tokenizer
204
+ self.vocab_size = tokenizer.vocab_size()
205
+
206
+ def decode(self, tokens):
207
+ if torch.is_tensor(tokens):
208
+ tokens = tokens.tolist()
209
+
210
+ return self.tokenizer.decode(tokens, ignore_ids = [0])
211
+
212
+ def encode(self, texts):
213
+ encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
214
+ return list(map(torch.tensor, encoded))
215
+
216
+ def tokenize(self, texts, context_length = 256, truncate_text = False):
217
+ if isinstance(texts, str):
218
+ texts = [texts]
219
+
220
+ all_tokens = self.encode(texts)
221
+
222
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
223
+ for i, tokens in enumerate(all_tokens):
224
+ if len(tokens) > context_length:
225
+ if truncate_text:
226
+ tokens = tokens[:context_length]
227
+ else:
228
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
229
+ result[i, :len(tokens)] = tokens.detach().clone()
230
+
231
+ return result
232
+
233
+ # txt_tokenizer = YttmTokenizer(bpe_path = "data/byte-level-bpe.tokenizer.json")
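
A sketch of round-tripping a caption through the module-level tokenizer instantiated above:

# tokenize pads/truncates to context_length; decode drops the padding id 0
from text2punks.tokenizer import txt_tokenizer

ids = txt_tokenizer.tokenize(
    'A male CryptoPunk that has 2 attributes, a beanie and a pipe.',
    context_length=40, truncate_text=True)
print(ids.shape)                     # torch.Size([1, 40])
print(txt_tokenizer.decode(ids[0]))  # lower-cased caption, padding removed
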
text2punks/transformer.py ADDED
@@ -0,0 +1,115 @@
1
+ from functools import partial
2
+ from itertools import islice, cycle
3
+
4
+ from torch import nn
5
+
6
+ from text2punks.attention import Attention, SparseAxialCausalAttention
7
+
8
+ # helpers
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+ def default(val, d):
14
+ return val if exists(val) else d
15
+
16
+ def cast_tuple(val, depth = 1):
17
+ if isinstance(val, list):
18
+ val = tuple(val)
19
+ return val if isinstance(val, tuple) else (val,) * depth
20
+
21
+ # classes
22
+
23
+ class SequentialSequence(nn.Module):
24
+ def __init__(self, layers):
25
+ super().__init__()
26
+ self.layers = layers
27
+
28
+ def forward(self, x):
29
+ for (f, g) in list(self.layers):
30
+ x = x + f(x)
31
+ x = x + g(x)
32
+ return x
33
+
34
+ class PreNorm(nn.Module):
35
+ def __init__(self, dim, fn):
36
+ super().__init__()
37
+ self.norm = nn.LayerNorm(dim)
38
+ self.fn = fn
39
+
40
+ def forward(self, x, **kwargs):
41
+ return self.fn(self.norm(x), **kwargs)
42
+
43
+ class FeedForward(nn.Module):
44
+ def __init__(self, dim, dropout = 0.):
45
+ super().__init__()
46
+ self.net = nn.Sequential(
47
+ nn.Linear(dim, dim * 4),
48
+ nn.GELU(),
49
+ nn.Dropout(dropout),
50
+ nn.Linear(dim * 4, dim)
51
+ )
52
+ # note: dropout here precedes the final nn.Linear(dim * 4, dim); GPT-style blocks instead apply nn.Dropout(resid_pdrop) after that projection
53
+
54
+ def forward(self, x):
55
+ return self.net(x)
56
+
57
+
58
+ class Transformer(nn.Module):
59
+ def __init__(
60
+ self,
61
+ *,
62
+ dim,
63
+ depth,
64
+ seq_len,
65
+ causal = True,
66
+ heads = 8,
67
+ dim_head = 64,
68
+ attn_dropout = 0.,
69
+ resid_dropout = 0.,
70
+ embd_dropout = 0.,
71
+ ff_dropout = 0.,
72
+ image_size = 24,
73
+ attn_types = None,
74
+ ):
75
+ super().__init__()
76
+ layers = nn.ModuleList([])
77
+
78
+ attn_types = default(attn_types, ('full',))
79
+ attn_types = cast_tuple(attn_types)
80
+ attn_type_layer = islice(cycle(attn_types), depth)
81
+
82
+ for attn_type in attn_type_layer:
83
+ if attn_type == 'full':
84
+ attn_class = partial(Attention, causal = causal)
85
+ elif attn_type == 'axial_row':
86
+ attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
87
+ elif attn_type == 'axial_col':
88
+ attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
89
+ else:
90
+ raise ValueError(f'attention type "{attn_type}" is not valid')
91
+
92
+ attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)
93
+
94
+ layers.append(nn.ModuleList([
95
+ PreNorm(dim, attn),
96
+ PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
97
+ ]))
98
+
99
+ # full attention in the last layer
100
+
101
+ attn_class = partial(Attention, causal = causal)
102
+ attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)
103
+
104
+ layers.append(nn.ModuleList([
105
+ PreNorm(dim, attn),
106
+ PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
107
+ ]))
108
+
109
+ self.layers = SequentialSequence(layers)
110
+ self.embd_drop = nn.Dropout(embd_dropout)
111
+
112
+ def forward(self, x):
113
+ x = self.embd_drop(x)
114
+ return self.layers(x)
115
+
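
A sketch of building a small decoder stack that alternates row/column axial attention, plus the full-attention layer the class always appends (sizes are illustrative):

# illustrative: 2 axial layers + the trailing full-attention layer
import torch
from text2punks.transformer import Transformer

decoder = Transformer(dim=64, depth=2, seq_len=40 + 24 * 24, causal=True,
                      heads=4, dim_head=16, image_size=24,
                      attn_types=('axial_row', 'axial_col'))
x = torch.randn(2, 40 + 24 * 24, 64)  # (batch, text + image positions, dim)
print(decoder(x).shape)               # torch.Size([2, 616, 64])
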
text2punks/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ # os
2
+
3
+ from pathlib import Path
4
+
5
+ # torch
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ from einops import repeat
10
+
11
+ # Text2Punks and Tokenizer
12
+
13
+ from text2punks.text2punk import Text2Punks, CLIP
14
+ from text2punks.tokenizer import txt_tokenizer
15
+
16
+ # select device
17
+
18
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
+
20
+ # load decoder
21
+
22
+ codebook = torch.load('./text2punks/data/codebook.pt')
23
+
24
+ # helper fns
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+
30
+ def to_pil_image(image_tensor):
31
+ return F.to_pil_image(image_tensor)
32
+
33
+
34
+ def model_loader(text2punk_path, clip_path):
35
+ # load pre-trained TEXT2PUNKS model
36
+
37
+ text2punk_path = Path(text2punk_path)
38
+ assert text2punk_path.exists(), 'trained Text2Punks must exist'
39
+
40
+ load_obj = torch.load(str(text2punk_path), map_location=torch.device(device))
41
+ text2punks_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')
42
+
43
+ text2punk = Text2Punks(**text2punks_params).to(device)
44
+ text2punk.load_state_dict(weights)
45
+
46
+ # load pre-trained CLIP model
47
+
48
+ clip_path = Path(clip_path)
49
+ assert clip_path.exists(), 'trained CLIP must exist'
50
+
51
+ load_obj = torch.load(str(clip_path), map_location=torch.device(device))
52
+ clip_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')
53
+
54
+ clip = CLIP(**clip_params).to(device)
55
+ clip.load_state_dict(weights)
56
+
57
+ return text2punk, clip
58
+
59
+
60
+ def generate_image(prompt_text, top_k, temperature, num_images, batch_size, top_prediction, text2punk_model, clip_model, codebook=codebook):
61
+ text = txt_tokenizer.tokenize(prompt_text, text2punk_model.text_seq_len, truncate_text=True).to(device)
62
+
63
+ text = repeat(text, '() n -> b n', b = num_images)
64
+
65
+ img_outputs = []
66
+ score_outputs = []
67
+
68
+ for text_chunk in text.split(batch_size):
69
+ images, scores = text2punk_model.generate_images(text_chunk, codebook.to(device), clip = clip_model, filter_thres = top_k, temperature = temperature)
70
+ img_outputs.append(images)
71
+ score_outputs.append(scores)
72
+
73
+ img_outputs = torch.cat(img_outputs)
74
+ score_outputs = torch.cat(score_outputs)
75
+
76
+ similarity = score_outputs.softmax(dim=-1)
77
+ values, indices = similarity.topk(top_prediction)
78
+
79
+ img_outputs = img_outputs[indices]
80
+ score_outputs = score_outputs[indices]
81
+
82
+ return img_outputs, score_outputs