Spaces:
Build error
Build error
Khalil
commited on
Commit
•
b41a54a
1
Parent(s):
416f940
First commit, add text2punps scripts, app file, and requirements file
Browse files- app.py +82 -0
- requirements.txt +9 -0
- text2punks/attention.py +175 -0
- text2punks/data/byte-level-bpe_4k.tokenizer.json +969 -0
- text2punks/data/codebook.pt +0 -0
- text2punks/loader.py +96 -0
- text2punks/text2punk.py +377 -0
- text2punks/tokenizer.py +233 -0
- text2punks/transformer.py +115 -0
- text2punks/utils.py +82 -0
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# system
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
os.system("gdown https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH")
|
6 |
+
os.system("gdown https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF")
|
7 |
+
|
8 |
+
# plot
|
9 |
+
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
# gradio
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
# text2punks utils
|
19 |
+
|
20 |
+
from text2punks.utils import to_pil_image, model_loader, generate_image
|
21 |
+
|
22 |
+
|
23 |
+
batch_size = 32
|
24 |
+
num_images = 32
|
25 |
+
top_prediction = 8
|
26 |
+
|
27 |
+
# nobs to tune
|
28 |
+
|
29 |
+
top_k = 0.8
|
30 |
+
temperature = 1.25
|
31 |
+
|
32 |
+
# helper functions
|
33 |
+
|
34 |
+
def compose_predictions(images):
|
35 |
+
|
36 |
+
increased_h = 0
|
37 |
+
h, w = images[0].shape[0], images[0].shape[1]
|
38 |
+
image_grid = Image.new("RGB", (len(images)*w, h))
|
39 |
+
|
40 |
+
for i, img_ in enumerate(images):
|
41 |
+
image_grid.paste(to_pil_image(img_), (i*w, increased_h))
|
42 |
+
|
43 |
+
return img
|
44 |
+
|
45 |
+
|
46 |
+
def run_inference(prompt, num_images=32, num_preds=8):
|
47 |
+
|
48 |
+
t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
|
49 |
+
text2punk, clip = model_loader(t2p_path, clip_path)
|
50 |
+
|
51 |
+
images = generate_image(prompt_text=prompt, top_k=top_k, temperature=temperature, num_images=num_images, batch_size=batch_size, top_prediction=top_prediction, text2punk_model=text2punk, clip_model=clip)
|
52 |
+
predictions = compose_predictions(images)
|
53 |
+
|
54 |
+
output_title = f"""
|
55 |
+
<b>{prompt}</b>
|
56 |
+
"""
|
57 |
+
|
58 |
+
return (output_title, predictions)
|
59 |
+
|
60 |
+
|
61 |
+
outputs = [
|
62 |
+
gr.outputs.HTML(label=""), # To be used as title
|
63 |
+
gr.outputs.Image(label=''),
|
64 |
+
]
|
65 |
+
|
66 |
+
description = """
|
67 |
+
Text2Cryptopunks is an AI model that generates Cryptopunks images from text prompt:
|
68 |
+
"""
|
69 |
+
|
70 |
+
gr.Interface(run_inference,
|
71 |
+
inputs=[gr.inputs.Textbox(label='type somthing like this : "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
|
72 |
+
outputs=outputs,
|
73 |
+
title='Text2Cryptopunks',
|
74 |
+
description=description,
|
75 |
+
article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a>",
|
76 |
+
layout='vertical',
|
77 |
+
theme='huggingface',
|
78 |
+
examples=[['Cute Alien cryptopunk that has a 2 Attributes, a Pipe, and a Beanie.'], ['A low resolution photo of punky-looking Ape that has 2 Attributes, a Beanie, and a Medical Mask.']],
|
79 |
+
allow_flagging=False,
|
80 |
+
live=False,
|
81 |
+
# server_port=8999
|
82 |
+
).launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
einops
|
4 |
+
numpy
|
5 |
+
ftfy
|
6 |
+
regex
|
7 |
+
axial-positional-embedding
|
8 |
+
youtokentome
|
9 |
+
tokenizers
|
text2punks/attention.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
|
6 |
+
# helpers
|
7 |
+
|
8 |
+
def exists(val):
|
9 |
+
return val is not None
|
10 |
+
|
11 |
+
def max_neg_value(t):
|
12 |
+
return -torch.finfo(t.dtype).max
|
13 |
+
|
14 |
+
|
15 |
+
# classes
|
16 |
+
|
17 |
+
class Attention(nn.Module):
|
18 |
+
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
|
19 |
+
super().__init__()
|
20 |
+
inner_dim = dim_head * heads
|
21 |
+
self.heads = heads
|
22 |
+
self.seq_len = seq_len
|
23 |
+
self.scale = dim_head ** -0.5
|
24 |
+
|
25 |
+
self.causal = causal
|
26 |
+
self.attn_drop = nn.Dropout(attn_dropout)
|
27 |
+
|
28 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
29 |
+
self.to_out = nn.Sequential(
|
30 |
+
nn.Linear(inner_dim, dim),
|
31 |
+
nn.Dropout(resid_dropout)
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
h, device = self.heads, x.device
|
36 |
+
|
37 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
38 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
39 |
+
|
40 |
+
q = q * self.scale
|
41 |
+
|
42 |
+
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
43 |
+
mask_value = max_neg_value(dots)
|
44 |
+
|
45 |
+
if self.causal:
|
46 |
+
i, j = dots.shape[-2:]
|
47 |
+
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
|
48 |
+
dots.masked_fill_(mask, mask_value)
|
49 |
+
|
50 |
+
attn = torch.softmax(dots, dim=-1)
|
51 |
+
attn = self.attn_drop(attn)
|
52 |
+
|
53 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
54 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
55 |
+
out = self.to_out(out)
|
56 |
+
return out
|
57 |
+
|
58 |
+
|
59 |
+
# sparse axial causal attention
|
60 |
+
|
61 |
+
class SparseAxialCausalAttention(nn.Module):
|
62 |
+
def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
|
63 |
+
super().__init__()
|
64 |
+
assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
|
65 |
+
self.axis = axis
|
66 |
+
|
67 |
+
inner_dim = dim_head * heads
|
68 |
+
self.seq_len = seq_len
|
69 |
+
self.heads = heads
|
70 |
+
self.scale = dim_head ** -0.5
|
71 |
+
self.image_size = image_size
|
72 |
+
self.attn_drop = nn.Dropout(attn_dropout)
|
73 |
+
|
74 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
75 |
+
|
76 |
+
self.to_out = nn.Sequential(
|
77 |
+
nn.Linear(inner_dim, dim),
|
78 |
+
nn.Dropout(resid_dropout)
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
|
83 |
+
|
84 |
+
img_seq_len = img_size ** 2
|
85 |
+
text_len = seq_len + 1 - img_seq_len
|
86 |
+
|
87 |
+
# padding
|
88 |
+
|
89 |
+
padding = seq_len - n + 1
|
90 |
+
mask = torch.ones(b, text_len, device = device).bool()
|
91 |
+
|
92 |
+
x = F.pad(x, (0, 0, 0, padding), value = 0)
|
93 |
+
mask = mask[:, :text_len]
|
94 |
+
|
95 |
+
# derive queries / keys / values
|
96 |
+
|
97 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
98 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
|
99 |
+
|
100 |
+
# print(self.scale)
|
101 |
+
q = q * self.scale
|
102 |
+
|
103 |
+
((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
|
104 |
+
|
105 |
+
# text attention
|
106 |
+
|
107 |
+
dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
|
108 |
+
mask_value = max_neg_value(dots_text)
|
109 |
+
|
110 |
+
i, j = dots_text.shape[-2:]
|
111 |
+
text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
|
112 |
+
dots_text.masked_fill_(text_causal_mask, mask_value)
|
113 |
+
|
114 |
+
attn_text = torch.softmax(dots_text, dim = -1)
|
115 |
+
|
116 |
+
# attention dropout
|
117 |
+
|
118 |
+
attn_text = self.attn_drop(attn_text)
|
119 |
+
out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
|
120 |
+
|
121 |
+
# image attention
|
122 |
+
|
123 |
+
split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
|
124 |
+
merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
|
125 |
+
|
126 |
+
# split out axis
|
127 |
+
|
128 |
+
q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))
|
129 |
+
|
130 |
+
# similarity
|
131 |
+
|
132 |
+
dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
|
133 |
+
dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
|
134 |
+
|
135 |
+
dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
|
136 |
+
|
137 |
+
# mask so image has full attention to text, but causal along axis
|
138 |
+
|
139 |
+
bh, x, i, j = dots.shape
|
140 |
+
causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
|
141 |
+
causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)
|
142 |
+
|
143 |
+
mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
|
144 |
+
mask = torch.cat((~mask, causal_mask), dim = -1)
|
145 |
+
|
146 |
+
dots.masked_fill_(mask, mask_value)
|
147 |
+
|
148 |
+
# attention.
|
149 |
+
|
150 |
+
attn = torch.softmax(dots, dim = -1)
|
151 |
+
|
152 |
+
# attention dropout
|
153 |
+
|
154 |
+
attn = self.attn_drop(attn)
|
155 |
+
|
156 |
+
# aggregate
|
157 |
+
|
158 |
+
attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
|
159 |
+
|
160 |
+
out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
|
161 |
+
out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
|
162 |
+
|
163 |
+
out_image = out_image_to_image + out_image_to_text
|
164 |
+
|
165 |
+
# merge back axis
|
166 |
+
|
167 |
+
out_image = rearrange(out_image, merge_axis_einops, x = img_size)
|
168 |
+
|
169 |
+
# combine attended values for both text and image
|
170 |
+
|
171 |
+
out = torch.cat((out_text, out_image), dim = 1)
|
172 |
+
|
173 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
174 |
+
out = self.to_out(out)
|
175 |
+
return out[:, :n]
|
text2punks/data/byte-level-bpe_4k.tokenizer.json
ADDED
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "1.0",
|
3 |
+
"truncation": null,
|
4 |
+
"padding": null,
|
5 |
+
"added_tokens": [
|
6 |
+
{
|
7 |
+
"id": 0,
|
8 |
+
"special": true,
|
9 |
+
"content": "[PAD]",
|
10 |
+
"single_word": false,
|
11 |
+
"lstrip": false,
|
12 |
+
"rstrip": false,
|
13 |
+
"normalized": false
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"id": 1,
|
17 |
+
"special": true,
|
18 |
+
"content": "[SEP]",
|
19 |
+
"single_word": false,
|
20 |
+
"lstrip": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"normalized": false
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"normalizer": {
|
26 |
+
"type": "Lowercase"
|
27 |
+
},
|
28 |
+
"pre_tokenizer": {
|
29 |
+
"type": "ByteLevel",
|
30 |
+
"add_prefix_space": false,
|
31 |
+
"trim_offsets": true
|
32 |
+
},
|
33 |
+
"post_processor": {
|
34 |
+
"type": "ByteLevel",
|
35 |
+
"add_prefix_space": true,
|
36 |
+
"trim_offsets": true
|
37 |
+
},
|
38 |
+
"decoder": {
|
39 |
+
"type": "ByteLevel",
|
40 |
+
"add_prefix_space": true,
|
41 |
+
"trim_offsets": true
|
42 |
+
},
|
43 |
+
"model": {
|
44 |
+
"type": "BPE",
|
45 |
+
"dropout": null,
|
46 |
+
"unk_token": null,
|
47 |
+
"continuing_subword_prefix": null,
|
48 |
+
"end_of_word_suffix": null,
|
49 |
+
"fuse_unk": false,
|
50 |
+
"vocab": {
|
51 |
+
"[PAD]": 0,
|
52 |
+
"[SEP]": 1,
|
53 |
+
",": 2,
|
54 |
+
"-": 3,
|
55 |
+
".": 4,
|
56 |
+
"0": 5,
|
57 |
+
"1": 6,
|
58 |
+
"2": 7,
|
59 |
+
"3": 8,
|
60 |
+
"4": 9,
|
61 |
+
"5": 10,
|
62 |
+
"6": 11,
|
63 |
+
"7": 12,
|
64 |
+
"?": 13,
|
65 |
+
"a": 14,
|
66 |
+
"b": 15,
|
67 |
+
"c": 16,
|
68 |
+
"d": 17,
|
69 |
+
"e": 18,
|
70 |
+
"f": 19,
|
71 |
+
"g": 20,
|
72 |
+
"h": 21,
|
73 |
+
"i": 22,
|
74 |
+
"k": 23,
|
75 |
+
"l": 24,
|
76 |
+
"m": 25,
|
77 |
+
"n": 26,
|
78 |
+
"o": 27,
|
79 |
+
"p": 28,
|
80 |
+
"r": 29,
|
81 |
+
"s": 30,
|
82 |
+
"t": 31,
|
83 |
+
"u": 32,
|
84 |
+
"v": 33,
|
85 |
+
"w": 34,
|
86 |
+
"x": 35,
|
87 |
+
"y": 36,
|
88 |
+
"z": 37,
|
89 |
+
"Ċ": 38,
|
90 |
+
"Ġ": 39,
|
91 |
+
"Ġa": 40,
|
92 |
+
"nd": 41,
|
93 |
+
"Ġb": 42,
|
94 |
+
"ha": 43,
|
95 |
+
"le": 44,
|
96 |
+
"ma": 45,
|
97 |
+
"Ġc": 46,
|
98 |
+
"ro": 47,
|
99 |
+
"pu": 48,
|
100 |
+
"ck": 49,
|
101 |
+
"to": 50,
|
102 |
+
"Ġand": 51,
|
103 |
+
"ack": 52,
|
104 |
+
"ar": 53,
|
105 |
+
"Ġma": 54,
|
106 |
+
"nk": 55,
|
107 |
+
"gro": 56,
|
108 |
+
"und": 57,
|
109 |
+
"Ġback": 58,
|
110 |
+
"punk": 59,
|
111 |
+
"ground": 60,
|
112 |
+
"Ġbackground": 61,
|
113 |
+
"Ġha": 62,
|
114 |
+
"Ġcr": 63,
|
115 |
+
"in": 64,
|
116 |
+
"Ġs": 65,
|
117 |
+
"Ġo": 66,
|
118 |
+
"pto": 67,
|
119 |
+
"ypto": 68,
|
120 |
+
"Ġcrypto": 69,
|
121 |
+
"Ġcryptopunk": 70,
|
122 |
+
"Ġof": 71,
|
123 |
+
"es": 72,
|
124 |
+
"Ġw": 73,
|
125 |
+
"Ġwi": 74,
|
126 |
+
"th": 75,
|
127 |
+
"ing": 76,
|
128 |
+
"ho": 77,
|
129 |
+
"lo": 78,
|
130 |
+
"Ġwith": 79,
|
131 |
+
"Ġp": 80,
|
132 |
+
"Ġmale": 81,
|
133 |
+
"Ġg": 82,
|
134 |
+
"Ġf": 83,
|
135 |
+
"ear": 84,
|
136 |
+
"Ġhas": 85,
|
137 |
+
"Ġm": 86,
|
138 |
+
"lu": 87,
|
139 |
+
"Ġt": 88,
|
140 |
+
"re": 89,
|
141 |
+
"Ġfe": 90,
|
142 |
+
"de": 91,
|
143 |
+
"wn": 92,
|
144 |
+
"ok": 93,
|
145 |
+
"hoto": 94,
|
146 |
+
"look": 95,
|
147 |
+
"Ġphoto": 96,
|
148 |
+
"hat": 97,
|
149 |
+
"Ġthat": 98,
|
150 |
+
"Ġmade": 99,
|
151 |
+
"male": 100,
|
152 |
+
"Ġfemale": 101,
|
153 |
+
"tr": 102,
|
154 |
+
"ir": 103,
|
155 |
+
"Ġhair": 104,
|
156 |
+
"Ġsha": 105,
|
157 |
+
"ple": 106,
|
158 |
+
"rple": 107,
|
159 |
+
"Ġpu": 108,
|
160 |
+
"Ġpurple": 109,
|
161 |
+
"ut": 110,
|
162 |
+
"en": 111,
|
163 |
+
"Ġgre": 112,
|
164 |
+
"Ġgreen": 113,
|
165 |
+
"Ġshad": 114,
|
166 |
+
"Ġblu": 115,
|
167 |
+
"Ġblue": 116,
|
168 |
+
"Ġmo": 117,
|
169 |
+
"looking": 118,
|
170 |
+
"li": 119,
|
171 |
+
"Ġr": 120,
|
172 |
+
"rown": 121,
|
173 |
+
"ti": 122,
|
174 |
+
"Ġli": 123,
|
175 |
+
"la": 124,
|
176 |
+
"ap": 125,
|
177 |
+
"Ġbear": 126,
|
178 |
+
"Ġbeard": 127,
|
179 |
+
"are": 128,
|
180 |
+
"Ġbrown": 129,
|
181 |
+
"wk": 130,
|
182 |
+
"hawk": 131,
|
183 |
+
"Ġmohawk": 132,
|
184 |
+
"ig": 133,
|
185 |
+
"ring": 134,
|
186 |
+
"Ġear": 135,
|
187 |
+
"Ġearring": 136,
|
188 |
+
"ed": 137,
|
189 |
+
"ey": 138,
|
190 |
+
"Ġey": 139,
|
191 |
+
"but": 140,
|
192 |
+
"ibut": 141,
|
193 |
+
"ttr": 142,
|
194 |
+
"Ġattr": 143,
|
195 |
+
"punky": 144,
|
196 |
+
"Ġattribut": 145,
|
197 |
+
"ps": 146,
|
198 |
+
"Ġattributes": 147,
|
199 |
+
"Ġcap": 148,
|
200 |
+
"ss": 149,
|
201 |
+
"tick": 150,
|
202 |
+
"Ġlips": 151,
|
203 |
+
"Ġlipstick": 152,
|
204 |
+
"Ġshades": 153,
|
205 |
+
"lass": 154,
|
206 |
+
"Ġn": 155,
|
207 |
+
"Ġeye": 156,
|
208 |
+
"Ġpunky": 157,
|
209 |
+
"Ġrare": 158,
|
210 |
+
"Ġho": 159,
|
211 |
+
"ow": 160,
|
212 |
+
"Ġglass": 161,
|
213 |
+
"Ġglasses": 162,
|
214 |
+
"tt": 163,
|
215 |
+
"Ġshadow": 164,
|
216 |
+
"Ġ3": 165,
|
217 |
+
"Ġgo": 166,
|
218 |
+
"or": 167,
|
219 |
+
"Ġd": 168,
|
220 |
+
"mal": 169,
|
221 |
+
"Ġpi": 170,
|
222 |
+
"Ġstr": 171,
|
223 |
+
"he": 172,
|
224 |
+
"Ġclo": 173,
|
225 |
+
"Ġclown": 174,
|
226 |
+
"ke": 175,
|
227 |
+
"od": 176,
|
228 |
+
"ark": 177,
|
229 |
+
"Ġdark": 178,
|
230 |
+
"ld": 179,
|
231 |
+
"ce": 180,
|
232 |
+
"Ġcig": 181,
|
233 |
+
"arett": 182,
|
234 |
+
"Ġcigarett": 183,
|
235 |
+
"Ġcigarette": 184,
|
236 |
+
"lack": 185,
|
237 |
+
"Ġblack": 186,
|
238 |
+
"and": 187,
|
239 |
+
"Ġnor": 188,
|
240 |
+
"Ġnormal": 189,
|
241 |
+
"Ġ2": 190,
|
242 |
+
"nt": 191,
|
243 |
+
"ront": 192,
|
244 |
+
"Ġfront": 193,
|
245 |
+
"Ġlooking": 194,
|
246 |
+
"car": 195,
|
247 |
+
"cut": 196,
|
248 |
+
"ela": 197,
|
249 |
+
"on": 198,
|
250 |
+
"olu": 199,
|
251 |
+
"ted": 200,
|
252 |
+
"up": 201,
|
253 |
+
"xela": 202,
|
254 |
+
"Ġlo": 203,
|
255 |
+
"Ġlook": 204,
|
256 |
+
"Ġup": 205,
|
257 |
+
"Ġsing": 206,
|
258 |
+
"Ġscar": 207,
|
259 |
+
"esolu": 208,
|
260 |
+
"how": 209,
|
261 |
+
"Ġresolu": 210,
|
262 |
+
"tion": 211,
|
263 |
+
"Ġlike": 212,
|
264 |
+
"Ġgood": 213,
|
265 |
+
"Ġpixela": 214,
|
266 |
+
"cute": 215,
|
267 |
+
"Ġlow": 216,
|
268 |
+
"Ġsingle": 217,
|
269 |
+
"Ġscarce": 218,
|
270 |
+
"Ġresolution": 219,
|
271 |
+
"Ġpixelated": 220,
|
272 |
+
"fu": 221,
|
273 |
+
"nn": 222,
|
274 |
+
"funn": 223,
|
275 |
+
"funny": 224,
|
276 |
+
"Ġeyes": 225,
|
277 |
+
"Ġhe": 226,
|
278 |
+
"at": 227,
|
279 |
+
"Ġv": 228,
|
280 |
+
"aig": 229,
|
281 |
+
"ht": 230,
|
282 |
+
"Ġstraig": 231,
|
283 |
+
"Ġstraight": 232,
|
284 |
+
"er": 233,
|
285 |
+
"Ġwild": 234,
|
286 |
+
"ad": 235,
|
287 |
+
"Ġhead": 236,
|
288 |
+
"Ġhot": 237,
|
289 |
+
"Ġbig": 238,
|
290 |
+
"ic": 239,
|
291 |
+
"Ġre": 240,
|
292 |
+
"Ġmole": 241,
|
293 |
+
"an": 242,
|
294 |
+
"mp": 243,
|
295 |
+
"sy": 244,
|
296 |
+
"us": 245,
|
297 |
+
"Ġner": 246,
|
298 |
+
"Ġnerd": 247,
|
299 |
+
"nde": 248,
|
300 |
+
"Ġblo": 249,
|
301 |
+
"Ġblonde": 250,
|
302 |
+
"im": 251,
|
303 |
+
"ned": 252,
|
304 |
+
"rned": 253,
|
305 |
+
"Ġrim": 254,
|
306 |
+
"Ġhorned": 255,
|
307 |
+
"Ġhat": 256,
|
308 |
+
"gu": 257,
|
309 |
+
"lar": 258,
|
310 |
+
"Ġregu": 259,
|
311 |
+
"Ġregular": 260,
|
312 |
+
"Ġclass": 261,
|
313 |
+
"Ġclassic": 262,
|
314 |
+
"Ġband": 263,
|
315 |
+
"ana": 264,
|
316 |
+
"Ġbandana": 265,
|
317 |
+
"sk": 266,
|
318 |
+
"Ġmask": 267,
|
319 |
+
"ingy": 268,
|
320 |
+
"Ġstringy": 269,
|
321 |
+
"ch": 270,
|
322 |
+
"Ġpat": 271,
|
323 |
+
"Ġpatch": 272,
|
324 |
+
"essy": 273,
|
325 |
+
"Ġmessy": 274,
|
326 |
+
"ved": 275,
|
327 |
+
"Ġshaved": 276,
|
328 |
+
"ru": 277,
|
329 |
+
"Ġfru": 278,
|
330 |
+
"mpy": 279,
|
331 |
+
"Ġfrumpy": 280,
|
332 |
+
"Ġth": 281,
|
333 |
+
"Ġthin": 282,
|
334 |
+
"Ġsp": 283,
|
335 |
+
"itt": 284,
|
336 |
+
"kn": 285,
|
337 |
+
"Ġkn": 286,
|
338 |
+
"itted": 287,
|
339 |
+
"Ġknitted": 288,
|
340 |
+
"az": 289,
|
341 |
+
"Ġcraz": 290,
|
342 |
+
"Ġcrazy": 291,
|
343 |
+
"band": 292,
|
344 |
+
"Ġheadband": 293,
|
345 |
+
"ie": 294,
|
346 |
+
"ta": 295,
|
347 |
+
"Ġsmal": 296,
|
348 |
+
"Ġsmall": 297,
|
349 |
+
"pe": 298,
|
350 |
+
"Ġvr": 299,
|
351 |
+
"Ġ4": 300,
|
352 |
+
"hain": 301,
|
353 |
+
"Ġchain": 302,
|
354 |
+
"Ġpipe": 303,
|
355 |
+
"ak": 304,
|
356 |
+
"cho": 305,
|
357 |
+
"eak": 306,
|
358 |
+
"ike": 307,
|
359 |
+
"ncho": 308,
|
360 |
+
"toncho": 309,
|
361 |
+
"Ġpeak": 310,
|
362 |
+
"Ġmut": 311,
|
363 |
+
"Ġspike": 312,
|
364 |
+
"tonchops": 313,
|
365 |
+
"Ġmuttonchops": 314,
|
366 |
+
"ag": 315,
|
367 |
+
"rag": 316,
|
368 |
+
"Ġdo": 317,
|
369 |
+
"Ġgoat": 318,
|
370 |
+
"che": 319,
|
371 |
+
"Ġmus": 320,
|
372 |
+
"tache": 321,
|
373 |
+
"Ġmustache": 322,
|
374 |
+
"ur": 323,
|
375 |
+
"io": 324,
|
376 |
+
"xur": 325,
|
377 |
+
"Ġlu": 326,
|
378 |
+
"ious": 327,
|
379 |
+
"xurious": 328,
|
380 |
+
"Ġluxurious": 329,
|
381 |
+
"hin": 330,
|
382 |
+
"str": 331,
|
383 |
+
"Ġchin": 332,
|
384 |
+
"strap": 333,
|
385 |
+
"Ġchinstrap": 334,
|
386 |
+
"ape": 335,
|
387 |
+
"Ġvape": 336,
|
388 |
+
"bar": 337,
|
389 |
+
"ndle": 338,
|
390 |
+
"Ġhandle": 339,
|
391 |
+
"bars": 340,
|
392 |
+
"Ġhandlebars": 341,
|
393 |
+
"Ġfrown": 342,
|
394 |
+
"Ġhood": 343,
|
395 |
+
"Ġhoodie": 344,
|
396 |
+
"war": 345,
|
397 |
+
"Ġfor": 346,
|
398 |
+
"ward": 347,
|
399 |
+
"Ġforward": 348,
|
400 |
+
"il": 349,
|
401 |
+
"ile": 350,
|
402 |
+
"mile": 351,
|
403 |
+
"Ġsmile": 352,
|
404 |
+
"Ġno": 353,
|
405 |
+
"se": 354,
|
406 |
+
"Ġnose": 355,
|
407 |
+
"oli": 356,
|
408 |
+
"Ġpoli": 357,
|
409 |
+
"Ġpolice": 358,
|
410 |
+
"dor": 359,
|
411 |
+
"Ġfedor": 360,
|
412 |
+
"Ġfedora": 361,
|
413 |
+
"ass": 362,
|
414 |
+
"Ġtass": 363,
|
415 |
+
"Ġtassle": 364,
|
416 |
+
"al": 365,
|
417 |
+
"Ġmed": 366,
|
418 |
+
"ical": 367,
|
419 |
+
"Ġmedical": 368,
|
420 |
+
"Ġgold": 369,
|
421 |
+
"ver": 370,
|
422 |
+
"Ġsil": 371,
|
423 |
+
"Ġsilver": 372,
|
424 |
+
"amp": 373,
|
425 |
+
"ire": 374,
|
426 |
+
"lf": 375,
|
427 |
+
"ob": 376,
|
428 |
+
"Ġbob": 377,
|
429 |
+
"Ġhalf": 378,
|
430 |
+
"Ġvamp": 379,
|
431 |
+
"Ġred": 380,
|
432 |
+
"Ġvampire": 381,
|
433 |
+
"bo": 382,
|
434 |
+
"Ġcow": 383,
|
435 |
+
"boy": 384,
|
436 |
+
"Ġcowboy": 385,
|
437 |
+
"hi": 386,
|
438 |
+
"te": 387,
|
439 |
+
"Ġwhi": 388,
|
440 |
+
"Ġwhite": 389,
|
441 |
+
"rt": 390,
|
442 |
+
"Ġsho": 391,
|
443 |
+
"Ġshort": 392,
|
444 |
+
"ek": 393,
|
445 |
+
"Ġro": 394,
|
446 |
+
"Ġche": 395,
|
447 |
+
"eks": 396,
|
448 |
+
"Ġrosy": 397,
|
449 |
+
"Ġcheeks": 398,
|
450 |
+
"ot": 399,
|
451 |
+
"Ġspot": 400,
|
452 |
+
"Ġspots": 401,
|
453 |
+
"Ġto": 402,
|
454 |
+
"Ġtop": 403,
|
455 |
+
"Ġpink": 404,
|
456 |
+
"Ġpig": 405,
|
457 |
+
"tail": 406,
|
458 |
+
"Ġpigtail": 407,
|
459 |
+
"Ġpigtails": 408,
|
460 |
+
"Ġz": 409,
|
461 |
+
"bie": 410,
|
462 |
+
"mbie": 411,
|
463 |
+
"ombie": 412,
|
464 |
+
"Ġzombie": 413,
|
465 |
+
"eld": 414,
|
466 |
+
"gg": 415,
|
467 |
+
"les": 416,
|
468 |
+
"Ġweld": 417,
|
469 |
+
"Ġgogg": 418,
|
470 |
+
"Ġwelding": 419,
|
471 |
+
"Ġgoggles": 420,
|
472 |
+
"ee": 421,
|
473 |
+
"uck": 422,
|
474 |
+
"Ġbuck": 423,
|
475 |
+
"Ġtee": 424,
|
476 |
+
"Ġteeth": 425,
|
477 |
+
"Ġ1": 426,
|
478 |
+
"ge": 427,
|
479 |
+
"ide": 428,
|
480 |
+
"ran": 429,
|
481 |
+
"Ġside": 430,
|
482 |
+
"Ġoran": 431,
|
483 |
+
"Ġorange": 432,
|
484 |
+
"Ġattribute": 433,
|
485 |
+
"iar": 434,
|
486 |
+
"Ġtiar": 435,
|
487 |
+
"Ġtiara": 436,
|
488 |
+
"et": 437,
|
489 |
+
"lm": 438,
|
490 |
+
"lot": 439,
|
491 |
+
"Ġpilot": 440,
|
492 |
+
"Ġhelm": 441,
|
493 |
+
"Ġhelmet": 442,
|
494 |
+
"Ġcho": 443,
|
495 |
+
"ker": 444,
|
496 |
+
"Ġchoker": 445,
|
497 |
+
"ean": 446,
|
498 |
+
"Ġbean": 447,
|
499 |
+
"Ġbeanie": 448,
|
500 |
+
"Ġ5": 449,
|
501 |
+
"Ġape": 450,
|
502 |
+
"Ġali": 451,
|
503 |
+
"Ġalien": 452,
|
504 |
+
"Ġ6": 453,
|
505 |
+
"imple": 454,
|
506 |
+
"Ġ0": 455,
|
507 |
+
"Ġsimple": 456,
|
508 |
+
"Ġfeat": 457,
|
509 |
+
"ures": 458,
|
510 |
+
"Ġfeatures": 459,
|
511 |
+
"ace": 460,
|
512 |
+
"ase": 461,
|
513 |
+
"ero": 462,
|
514 |
+
"lin": 463,
|
515 |
+
"simple": 464,
|
516 |
+
"Ġaver": 465,
|
517 |
+
"Ġbut": 466,
|
518 |
+
"Ġbare": 467,
|
519 |
+
"Ġbase": 468,
|
520 |
+
"thing": 469,
|
521 |
+
"Ġface": 470,
|
522 |
+
"age": 471,
|
523 |
+
"Ġnothing": 472,
|
524 |
+
"Ġzero": 473,
|
525 |
+
"line": 474,
|
526 |
+
"Ġaverage": 475,
|
527 |
+
"Ġ7": 476
|
528 |
+
},
|
529 |
+
"merges": [
|
530 |
+
"Ġ a",
|
531 |
+
"n d",
|
532 |
+
"Ġ b",
|
533 |
+
"h a",
|
534 |
+
"l e",
|
535 |
+
"m a",
|
536 |
+
"Ġ c",
|
537 |
+
"r o",
|
538 |
+
"p u",
|
539 |
+
"c k",
|
540 |
+
"t o",
|
541 |
+
"Ġa nd",
|
542 |
+
"a ck",
|
543 |
+
"a r",
|
544 |
+
"Ġ ma",
|
545 |
+
"n k",
|
546 |
+
"g ro",
|
547 |
+
"u nd",
|
548 |
+
"Ġb ack",
|
549 |
+
"pu nk",
|
550 |
+
"gro und",
|
551 |
+
"Ġback ground",
|
552 |
+
"Ġ ha",
|
553 |
+
"Ġc r",
|
554 |
+
"i n",
|
555 |
+
"Ġ s",
|
556 |
+
"Ġ o",
|
557 |
+
"p to",
|
558 |
+
"y pto",
|
559 |
+
"Ġcr ypto",
|
560 |
+
"Ġcrypto punk",
|
561 |
+
"Ġo f",
|
562 |
+
"e s",
|
563 |
+
"Ġ w",
|
564 |
+
"Ġw i",
|
565 |
+
"t h",
|
566 |
+
"in g",
|
567 |
+
"h o",
|
568 |
+
"l o",
|
569 |
+
"Ġwi th",
|
570 |
+
"Ġ p",
|
571 |
+
"Ġma le",
|
572 |
+
"Ġ g",
|
573 |
+
"Ġ f",
|
574 |
+
"e ar",
|
575 |
+
"Ġha s",
|
576 |
+
"Ġ m",
|
577 |
+
"l u",
|
578 |
+
"Ġ t",
|
579 |
+
"r e",
|
580 |
+
"Ġf e",
|
581 |
+
"d e",
|
582 |
+
"w n",
|
583 |
+
"o k",
|
584 |
+
"ho to",
|
585 |
+
"lo ok",
|
586 |
+
"Ġp hoto",
|
587 |
+
"ha t",
|
588 |
+
"Ġt hat",
|
589 |
+
"Ġma de",
|
590 |
+
"ma le",
|
591 |
+
"Ġfe male",
|
592 |
+
"t r",
|
593 |
+
"i r",
|
594 |
+
"Ġha ir",
|
595 |
+
"Ġs ha",
|
596 |
+
"p le",
|
597 |
+
"r ple",
|
598 |
+
"Ġ pu",
|
599 |
+
"Ġpu rple",
|
600 |
+
"u t",
|
601 |
+
"e n",
|
602 |
+
"Ġg re",
|
603 |
+
"Ġgre en",
|
604 |
+
"Ġsha d",
|
605 |
+
"Ġb lu",
|
606 |
+
"Ġblu e",
|
607 |
+
"Ġm o",
|
608 |
+
"look ing",
|
609 |
+
"l i",
|
610 |
+
"Ġ r",
|
611 |
+
"ro wn",
|
612 |
+
"t i",
|
613 |
+
"Ġ li",
|
614 |
+
"l a",
|
615 |
+
"a p",
|
616 |
+
"Ġb ear",
|
617 |
+
"Ġbear d",
|
618 |
+
"ar e",
|
619 |
+
"Ġb rown",
|
620 |
+
"w k",
|
621 |
+
"ha wk",
|
622 |
+
"Ġmo hawk",
|
623 |
+
"i g",
|
624 |
+
"r ing",
|
625 |
+
"Ġ ear",
|
626 |
+
"Ġear ring",
|
627 |
+
"e d",
|
628 |
+
"e y",
|
629 |
+
"Ġ ey",
|
630 |
+
"b ut",
|
631 |
+
"i but",
|
632 |
+
"t tr",
|
633 |
+
"Ġa ttr",
|
634 |
+
"punk y",
|
635 |
+
"Ġattr ibut",
|
636 |
+
"p s",
|
637 |
+
"Ġattribut es",
|
638 |
+
"Ġc ap",
|
639 |
+
"s s",
|
640 |
+
"ti ck",
|
641 |
+
"Ġli ps",
|
642 |
+
"Ġlips tick",
|
643 |
+
"Ġshad es",
|
644 |
+
"la ss",
|
645 |
+
"Ġ n",
|
646 |
+
"Ġey e",
|
647 |
+
"Ġ punky",
|
648 |
+
"Ġr are",
|
649 |
+
"Ġ ho",
|
650 |
+
"o w",
|
651 |
+
"Ġg lass",
|
652 |
+
"Ġglass es",
|
653 |
+
"t t",
|
654 |
+
"Ġshad ow",
|
655 |
+
"Ġ 3",
|
656 |
+
"Ġg o",
|
657 |
+
"o r",
|
658 |
+
"Ġ d",
|
659 |
+
"ma l",
|
660 |
+
"Ġp i",
|
661 |
+
"Ġs tr",
|
662 |
+
"h e",
|
663 |
+
"Ġc lo",
|
664 |
+
"Ġclo wn",
|
665 |
+
"k e",
|
666 |
+
"o d",
|
667 |
+
"ar k",
|
668 |
+
"Ġd ark",
|
669 |
+
"l d",
|
670 |
+
"c e",
|
671 |
+
"Ġc ig",
|
672 |
+
"are tt",
|
673 |
+
"Ġcig arett",
|
674 |
+
"Ġcigarett e",
|
675 |
+
"l ack",
|
676 |
+
"Ġb lack",
|
677 |
+
"a nd",
|
678 |
+
"Ġn or",
|
679 |
+
"Ġnor mal",
|
680 |
+
"Ġ 2",
|
681 |
+
"n t",
|
682 |
+
"ro nt",
|
683 |
+
"Ġf ront",
|
684 |
+
"Ġ looking",
|
685 |
+
"c ar",
|
686 |
+
"c ut",
|
687 |
+
"e la",
|
688 |
+
"o n",
|
689 |
+
"o lu",
|
690 |
+
"t ed",
|
691 |
+
"u p",
|
692 |
+
"x ela",
|
693 |
+
"Ġ lo",
|
694 |
+
"Ġ look",
|
695 |
+
"Ġ up",
|
696 |
+
"Ġs ing",
|
697 |
+
"Ġs car",
|
698 |
+
"es olu",
|
699 |
+
"ho w",
|
700 |
+
"Ġr esolu",
|
701 |
+
"ti on",
|
702 |
+
"Ġli ke",
|
703 |
+
"Ġgo od",
|
704 |
+
"Ġpi xela",
|
705 |
+
"cut e",
|
706 |
+
"Ġlo w",
|
707 |
+
"Ġsing le",
|
708 |
+
"Ġscar ce",
|
709 |
+
"Ġresolu tion",
|
710 |
+
"Ġpixela ted",
|
711 |
+
"f u",
|
712 |
+
"n n",
|
713 |
+
"fu nn",
|
714 |
+
"funn y",
|
715 |
+
"Ġey es",
|
716 |
+
"Ġ he",
|
717 |
+
"a t",
|
718 |
+
"Ġ v",
|
719 |
+
"a ig",
|
720 |
+
"h t",
|
721 |
+
"Ġstr aig",
|
722 |
+
"Ġstraig ht",
|
723 |
+
"e r",
|
724 |
+
"Ġwi ld",
|
725 |
+
"a d",
|
726 |
+
"Ġhe ad",
|
727 |
+
"Ġho t",
|
728 |
+
"Ġb ig",
|
729 |
+
"i c",
|
730 |
+
"Ġ re",
|
731 |
+
"Ġmo le",
|
732 |
+
"a n",
|
733 |
+
"m p",
|
734 |
+
"s y",
|
735 |
+
"u s",
|
736 |
+
"Ġn er",
|
737 |
+
"Ġner d",
|
738 |
+
"nd e",
|
739 |
+
"Ġb lo",
|
740 |
+
"Ġblo nde",
|
741 |
+
"i m",
|
742 |
+
"n ed",
|
743 |
+
"r ned",
|
744 |
+
"Ġr im",
|
745 |
+
"Ġho rned",
|
746 |
+
"Ġha t",
|
747 |
+
"g u",
|
748 |
+
"l ar",
|
749 |
+
"Ġre gu",
|
750 |
+
"Ġregu lar",
|
751 |
+
"Ġc lass",
|
752 |
+
"Ġclass ic",
|
753 |
+
"Ġb and",
|
754 |
+
"an a",
|
755 |
+
"Ġband ana",
|
756 |
+
"s k",
|
757 |
+
"Ġma sk",
|
758 |
+
"ing y",
|
759 |
+
"Ġstr ingy",
|
760 |
+
"c h",
|
761 |
+
"Ġp at",
|
762 |
+
"Ġpat ch",
|
763 |
+
"es sy",
|
764 |
+
"Ġm essy",
|
765 |
+
"v ed",
|
766 |
+
"Ġsha ved",
|
767 |
+
"r u",
|
768 |
+
"Ġf ru",
|
769 |
+
"mp y",
|
770 |
+
"Ġfru mpy",
|
771 |
+
"Ġ th",
|
772 |
+
"Ġth in",
|
773 |
+
"Ġs p",
|
774 |
+
"i tt",
|
775 |
+
"k n",
|
776 |
+
"Ġ kn",
|
777 |
+
"itt ed",
|
778 |
+
"Ġkn itted",
|
779 |
+
"a z",
|
780 |
+
"Ġcr az",
|
781 |
+
"Ġcraz y",
|
782 |
+
"b and",
|
783 |
+
"Ġhead band",
|
784 |
+
"i e",
|
785 |
+
"t a",
|
786 |
+
"Ġs mal",
|
787 |
+
"Ġsmal l",
|
788 |
+
"p e",
|
789 |
+
"Ġv r",
|
790 |
+
"Ġ 4",
|
791 |
+
"ha in",
|
792 |
+
"Ġc hain",
|
793 |
+
"Ġpi pe",
|
794 |
+
"a k",
|
795 |
+
"c ho",
|
796 |
+
"e ak",
|
797 |
+
"i ke",
|
798 |
+
"n cho",
|
799 |
+
"to ncho",
|
800 |
+
"Ġp eak",
|
801 |
+
"Ġm ut",
|
802 |
+
"Ġsp ike",
|
803 |
+
"toncho ps",
|
804 |
+
"Ġmut tonchops",
|
805 |
+
"a g",
|
806 |
+
"r ag",
|
807 |
+
"Ġd o",
|
808 |
+
"Ġgo at",
|
809 |
+
"c he",
|
810 |
+
"Ġm us",
|
811 |
+
"ta che",
|
812 |
+
"Ġmus tache",
|
813 |
+
"u r",
|
814 |
+
"i o",
|
815 |
+
"x ur",
|
816 |
+
"Ġ lu",
|
817 |
+
"io us",
|
818 |
+
"xur ious",
|
819 |
+
"Ġlu xurious",
|
820 |
+
"h in",
|
821 |
+
"s tr",
|
822 |
+
"Ġc hin",
|
823 |
+
"str ap",
|
824 |
+
"Ġchin strap",
|
825 |
+
"ap e",
|
826 |
+
"Ġv ape",
|
827 |
+
"b ar",
|
828 |
+
"nd le",
|
829 |
+
"Ġha ndle",
|
830 |
+
"bar s",
|
831 |
+
"Ġhandle bars",
|
832 |
+
"Ġf rown",
|
833 |
+
"Ġho od",
|
834 |
+
"Ġhood ie",
|
835 |
+
"w ar",
|
836 |
+
"Ġf or",
|
837 |
+
"war d",
|
838 |
+
"Ġfor ward",
|
839 |
+
"i l",
|
840 |
+
"i le",
|
841 |
+
"m ile",
|
842 |
+
"Ġs mile",
|
843 |
+
"Ġn o",
|
844 |
+
"s e",
|
845 |
+
"Ġno se",
|
846 |
+
"o li",
|
847 |
+
"Ġp oli",
|
848 |
+
"Ġpoli ce",
|
849 |
+
"d or",
|
850 |
+
"Ġfe dor",
|
851 |
+
"Ġfedor a",
|
852 |
+
"a ss",
|
853 |
+
"Ġt ass",
|
854 |
+
"Ġtass le",
|
855 |
+
"a l",
|
856 |
+
"Ġm ed",
|
857 |
+
"ic al",
|
858 |
+
"Ġmed ical",
|
859 |
+
"Ġgo ld",
|
860 |
+
"v er",
|
861 |
+
"Ġs il",
|
862 |
+
"Ġsil ver",
|
863 |
+
"a mp",
|
864 |
+
"i re",
|
865 |
+
"l f",
|
866 |
+
"o b",
|
867 |
+
"Ġb ob",
|
868 |
+
"Ġha lf",
|
869 |
+
"Ġv amp",
|
870 |
+
"Ġre d",
|
871 |
+
"Ġvamp ire",
|
872 |
+
"b o",
|
873 |
+
"Ġc ow",
|
874 |
+
"bo y",
|
875 |
+
"Ġcow boy",
|
876 |
+
"h i",
|
877 |
+
"t e",
|
878 |
+
"Ġw hi",
|
879 |
+
"Ġwhi te",
|
880 |
+
"r t",
|
881 |
+
"Ġs ho",
|
882 |
+
"Ġsho rt",
|
883 |
+
"e k",
|
884 |
+
"Ġ ro",
|
885 |
+
"Ġc he",
|
886 |
+
"ek s",
|
887 |
+
"Ġro sy",
|
888 |
+
"Ġche eks",
|
889 |
+
"o t",
|
890 |
+
"Ġsp ot",
|
891 |
+
"Ġspot s",
|
892 |
+
"Ġ to",
|
893 |
+
"Ġto p",
|
894 |
+
"Ġpi nk",
|
895 |
+
"Ġp ig",
|
896 |
+
"ta il",
|
897 |
+
"Ġpig tail",
|
898 |
+
"Ġpigtail s",
|
899 |
+
"Ġ z",
|
900 |
+
"b ie",
|
901 |
+
"m bie",
|
902 |
+
"o mbie",
|
903 |
+
"Ġz ombie",
|
904 |
+
"e ld",
|
905 |
+
"g g",
|
906 |
+
"le s",
|
907 |
+
"Ġw eld",
|
908 |
+
"Ġgo gg",
|
909 |
+
"Ġweld ing",
|
910 |
+
"Ġgogg les",
|
911 |
+
"e e",
|
912 |
+
"u ck",
|
913 |
+
"Ġb uck",
|
914 |
+
"Ġt ee",
|
915 |
+
"Ġtee th",
|
916 |
+
"Ġ 1",
|
917 |
+
"g e",
|
918 |
+
"i de",
|
919 |
+
"r an",
|
920 |
+
"Ġs ide",
|
921 |
+
"Ġo ran",
|
922 |
+
"Ġoran ge",
|
923 |
+
"Ġattribut e",
|
924 |
+
"i ar",
|
925 |
+
"Ġt iar",
|
926 |
+
"Ġtiar a",
|
927 |
+
"e t",
|
928 |
+
"l m",
|
929 |
+
"lo t",
|
930 |
+
"Ġpi lot",
|
931 |
+
"Ġhe lm",
|
932 |
+
"Ġhelm et",
|
933 |
+
"Ġc ho",
|
934 |
+
"ke r",
|
935 |
+
"Ġcho ker",
|
936 |
+
"e an",
|
937 |
+
"Ġb ean",
|
938 |
+
"Ġbean ie",
|
939 |
+
"Ġ 5",
|
940 |
+
"Ġa pe",
|
941 |
+
"Ġa li",
|
942 |
+
"Ġali en",
|
943 |
+
"Ġ 6",
|
944 |
+
"im ple",
|
945 |
+
"Ġ 0",
|
946 |
+
"Ġs imple",
|
947 |
+
"Ġfe at",
|
948 |
+
"ur es",
|
949 |
+
"Ġfeat ures",
|
950 |
+
"a ce",
|
951 |
+
"a se",
|
952 |
+
"e ro",
|
953 |
+
"l in",
|
954 |
+
"s imple",
|
955 |
+
"Ġa ver",
|
956 |
+
"Ġb ut",
|
957 |
+
"Ġb are",
|
958 |
+
"Ġb ase",
|
959 |
+
"th ing",
|
960 |
+
"Ġf ace",
|
961 |
+
"ag e",
|
962 |
+
"Ġno thing",
|
963 |
+
"Ġz ero",
|
964 |
+
"lin e",
|
965 |
+
"Ġaver age",
|
966 |
+
"Ġ 7"
|
967 |
+
]
|
968 |
+
}
|
969 |
+
}
|
text2punks/data/codebook.pt
ADDED
Binary file (1.39 kB). View file
|
|
text2punks/loader.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image, UnidentifiedImageError
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
from random import randint, choice
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
|
11 |
+
class TextImageDataset(Dataset):
|
12 |
+
def __init__(self,
|
13 |
+
folder,
|
14 |
+
text_len=40,
|
15 |
+
truncate_captions=False,
|
16 |
+
text_tokenizer=None,
|
17 |
+
image_tokenizer=None,
|
18 |
+
shuffle=False
|
19 |
+
):
|
20 |
+
"""
|
21 |
+
@param folder: Folder containing images and text files matched by their paths' respective "stem"
|
22 |
+
@param truncate_captions: Rather than throw an exception, captions which are too long will be truncated.
|
23 |
+
"""
|
24 |
+
super().__init__()
|
25 |
+
self.shuffle = shuffle
|
26 |
+
path = Path(folder)
|
27 |
+
|
28 |
+
text_files = [*path.glob('**/*.txt')]
|
29 |
+
image_files = [
|
30 |
+
*path.glob('**/*.png'), *path.glob('**/*.jpg'),
|
31 |
+
*path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
|
32 |
+
]
|
33 |
+
|
34 |
+
text_files = {text_file.stem: text_file for text_file in text_files}
|
35 |
+
image_files = {image_file.stem: image_file for image_file in image_files}
|
36 |
+
|
37 |
+
keys = (image_files.keys() & text_files.keys())
|
38 |
+
|
39 |
+
self.keys = list(keys)
|
40 |
+
self.text_files = {k: v for k, v in text_files.items() if k in keys}
|
41 |
+
self.image_files = {k: v for k, v in image_files.items() if k in keys}
|
42 |
+
self.text_len = text_len
|
43 |
+
self.truncate_captions = truncate_captions
|
44 |
+
self.text_tokenizer = text_tokenizer
|
45 |
+
self.image_tokenizer = image_tokenizer
|
46 |
+
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.keys)
|
50 |
+
|
51 |
+
def random_sample(self):
|
52 |
+
return self.__getitem__(randint(0, self.__len__() - 1))
|
53 |
+
|
54 |
+
def sequential_sample(self, ind):
|
55 |
+
if ind >= self.__len__() - 1:
|
56 |
+
return self.__getitem__(0)
|
57 |
+
return self.__getitem__(ind + 1)
|
58 |
+
|
59 |
+
def skip_sample(self, ind):
|
60 |
+
if self.shuffle:
|
61 |
+
return self.random_sample()
|
62 |
+
return self.sequential_sample(ind=ind)
|
63 |
+
|
64 |
+
def __getitem__(self, ind):
|
65 |
+
key = self.keys[ind]
|
66 |
+
|
67 |
+
text_file = self.text_files[key]
|
68 |
+
image_file = self.image_files[key]
|
69 |
+
|
70 |
+
descriptions = text_file.read_text().split('\n')
|
71 |
+
descriptions = list(filter(lambda t: len(t) > 0, descriptions))
|
72 |
+
try:
|
73 |
+
description = choice(descriptions)
|
74 |
+
except IndexError as zero_captions_in_file_ex:
|
75 |
+
print(f"An exception occurred trying to load file {text_file}.")
|
76 |
+
print(f"Skipping index {ind}")
|
77 |
+
return self.skip_sample(ind)
|
78 |
+
|
79 |
+
tokenized_text = self.text_tokenizer.tokenize(
|
80 |
+
description,
|
81 |
+
self.text_len,
|
82 |
+
truncate_text=self.truncate_captions
|
83 |
+
).squeeze(0)
|
84 |
+
try:
|
85 |
+
image = Image.open(image_file).convert('RGB')
|
86 |
+
pixels = np.array(image).reshape(-1, 3)
|
87 |
+
|
88 |
+
tokenized_image = [self.image_tokenizer[str(idx)] for idx in pixels]
|
89 |
+
tokenized_image = torch.tensor(tokenized_image)
|
90 |
+
except (UnidentifiedImageError, OSError) as corrupt_image_exceptions:
|
91 |
+
print(f"An exception occurred trying to load file {image_file}.")
|
92 |
+
print(f"Skipping index {ind}")
|
93 |
+
return self.skip_sample(ind)
|
94 |
+
|
95 |
+
# Success
|
96 |
+
return tokenized_text, tokenized_image
|
text2punks/text2punk.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
from einops import rearrange, repeat
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn, einsum
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from axial_positional_embedding import AxialPositionalEmbedding
|
10 |
+
from text2punks.transformer import Transformer
|
11 |
+
|
12 |
+
|
13 |
+
# helpers fns
|
14 |
+
|
15 |
+
def exists(val):
|
16 |
+
return val is not None
|
17 |
+
|
18 |
+
def default(val, d):
|
19 |
+
return val if exists(val) else d
|
20 |
+
|
21 |
+
def set_requires_grad(model, value):
|
22 |
+
for param in model.parameters():
|
23 |
+
param.requires_grad = value
|
24 |
+
|
25 |
+
def eval_decorator(fn):
|
26 |
+
def inner(model, *args, **kwargs):
|
27 |
+
was_training = model.training
|
28 |
+
model.eval()
|
29 |
+
out = fn(model, *args, **kwargs)
|
30 |
+
model.train(was_training)
|
31 |
+
return out
|
32 |
+
return inner
|
33 |
+
|
34 |
+
# sampling helpers fn
|
35 |
+
|
36 |
+
def top_k(logits, thres = 0.5):
|
37 |
+
num_logits = logits.shape[-1]
|
38 |
+
k = max(int((1 - thres) * num_logits), 1)
|
39 |
+
val, ind = torch.topk(logits, k)
|
40 |
+
probs = torch.full_like(logits, float('-inf'))
|
41 |
+
probs.scatter_(1, ind, val)
|
42 |
+
return probs
|
43 |
+
|
44 |
+
# main CLIP class
|
45 |
+
|
46 |
+
class CLIP(nn.Module):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
*,
|
50 |
+
dim_text = 512,
|
51 |
+
dim_image = 512,
|
52 |
+
dim_latent = 512,
|
53 |
+
num_text_tokens = 10000,
|
54 |
+
text_enc_depth = 6,
|
55 |
+
text_seq_len = 256,
|
56 |
+
text_heads = 8,
|
57 |
+
num_visual_tokens = 256,
|
58 |
+
visual_enc_depth = 6,
|
59 |
+
visual_image_seq_len = 256,
|
60 |
+
visual_image_size = 24,
|
61 |
+
visual_heads = 8,
|
62 |
+
attn_pdrop = 0.1,
|
63 |
+
resid_pdrop = 0.1,
|
64 |
+
embd_pdrop = 0.1,
|
65 |
+
ff_dropout = 0.1,
|
66 |
+
attn_types = None
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
|
70 |
+
# Texts
|
71 |
+
|
72 |
+
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
73 |
+
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
74 |
+
|
75 |
+
self.text_transformer = Transformer(
|
76 |
+
dim = dim_text,
|
77 |
+
causal = False,
|
78 |
+
seq_len = text_seq_len,
|
79 |
+
depth = text_enc_depth,
|
80 |
+
heads = text_heads,
|
81 |
+
dim_head = dim_text // text_heads,
|
82 |
+
attn_dropout = attn_pdrop,
|
83 |
+
resid_dropout = resid_pdrop,
|
84 |
+
embd_dropout = embd_pdrop,
|
85 |
+
ff_dropout = ff_dropout,
|
86 |
+
attn_types = attn_types
|
87 |
+
)
|
88 |
+
|
89 |
+
self.text_ln = nn.LayerNorm(dim_text)
|
90 |
+
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
|
91 |
+
|
92 |
+
# Images
|
93 |
+
|
94 |
+
self.image_emb = nn.Embedding(num_visual_tokens, dim_image)
|
95 |
+
self.image_pos_emb = nn.Embedding(visual_image_seq_len, dim_image)
|
96 |
+
|
97 |
+
self.visual_transformer = Transformer(
|
98 |
+
dim = dim_image,
|
99 |
+
causal = False,
|
100 |
+
seq_len = visual_image_seq_len,
|
101 |
+
depth = visual_enc_depth,
|
102 |
+
heads = visual_heads,
|
103 |
+
dim_head = dim_image // visual_heads,
|
104 |
+
attn_dropout = attn_pdrop,
|
105 |
+
resid_dropout = resid_pdrop,
|
106 |
+
embd_dropout = embd_pdrop,
|
107 |
+
ff_dropout = ff_dropout,
|
108 |
+
attn_types = attn_types,
|
109 |
+
image_size = visual_image_size,
|
110 |
+
)
|
111 |
+
|
112 |
+
self.image_ln = nn.LayerNorm(dim_image)
|
113 |
+
self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
|
114 |
+
|
115 |
+
self.temperature = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
|
116 |
+
|
117 |
+
|
118 |
+
self.apply(self._init_weights)
|
119 |
+
|
120 |
+
def _init_weights(self, module):
|
121 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
122 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
123 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
124 |
+
module.bias.data.zero_()
|
125 |
+
elif isinstance(module, nn.LayerNorm):
|
126 |
+
module.bias.data.zero_()
|
127 |
+
module.weight.data.fill_(1.0)
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
text,
|
132 |
+
image,
|
133 |
+
return_loss = False
|
134 |
+
):
|
135 |
+
b, device= text.shape[0], text.device
|
136 |
+
|
137 |
+
text_emb = self.text_emb(text)
|
138 |
+
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
|
139 |
+
|
140 |
+
image_emb = self.image_emb(image)
|
141 |
+
image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))
|
142 |
+
|
143 |
+
enc_text = self.text_transformer(text_emb)
|
144 |
+
enc_image = self.visual_transformer(image_emb)
|
145 |
+
|
146 |
+
text_latents = enc_text.mean(dim = 1)
|
147 |
+
image_latents = enc_image.mean(dim = 1)
|
148 |
+
|
149 |
+
text_latents = self.text_ln(text_latents)
|
150 |
+
image_latents = self.image_ln(image_latents)
|
151 |
+
|
152 |
+
text_latents = self.to_text_latent(text_latents)
|
153 |
+
image_latents = self.to_visual_latent(image_latents)
|
154 |
+
|
155 |
+
text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))
|
156 |
+
|
157 |
+
temp = self.temperature.exp()
|
158 |
+
|
159 |
+
if not return_loss:
|
160 |
+
sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
|
161 |
+
return sim
|
162 |
+
|
163 |
+
sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
|
164 |
+
labels = torch.arange(b, device = device)
|
165 |
+
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
166 |
+
return loss
|
167 |
+
|
168 |
+
# main Text2Punks class
|
169 |
+
|
170 |
+
class Text2Punks(nn.Module):
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
*,
|
174 |
+
n_embd,
|
175 |
+
n_layer = 12,
|
176 |
+
n_head = 12,
|
177 |
+
d_head = 64,
|
178 |
+
num_text_tokens = 10000,
|
179 |
+
text_seq_len = 256,
|
180 |
+
num_image_tokens = 222,
|
181 |
+
image_seq_len = 576,
|
182 |
+
image_size = 24,
|
183 |
+
attn_pdrop = 0.1,
|
184 |
+
resid_pdrop = 0.1,
|
185 |
+
embd_pdrop = 0.1,
|
186 |
+
ff_dropout = 0.1,
|
187 |
+
attn_types = None,
|
188 |
+
loss_img_weight = 7,
|
189 |
+
loss_txt_weight = 7,
|
190 |
+
):
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
|
194 |
+
|
195 |
+
self.text_emb = nn.Embedding(num_text_tokens, n_embd)
|
196 |
+
self.image_emb = nn.Embedding(num_image_tokens, n_embd)
|
197 |
+
|
198 |
+
self.text_pos_emb = nn.Embedding(text_seq_len + 1, n_embd) # +1 for <bos> a.k.a <sos>
|
199 |
+
# self.image_pos_emb = nn.Embedding(image_seq_len, n_embd)
|
200 |
+
self.image_pos_emb = nn.Parameter(torch.zeros(1, image_seq_len, n_embd))
|
201 |
+
# self.image_pos_emb = AxialPositionalEmbedding(n_embd, axial_shape=(image_size, image_size))
|
202 |
+
|
203 |
+
self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
|
204 |
+
self.num_image_tokens = num_image_tokens
|
205 |
+
self.text_seq_len = text_seq_len
|
206 |
+
self.image_seq_len = image_seq_len
|
207 |
+
|
208 |
+
seq_len = text_seq_len + image_seq_len
|
209 |
+
total_tokens = num_text_tokens + num_image_tokens
|
210 |
+
self.total_seq_len = seq_len
|
211 |
+
self.total_tokens = total_tokens
|
212 |
+
|
213 |
+
self.transformer = Transformer(
|
214 |
+
dim = n_embd,
|
215 |
+
causal = True,
|
216 |
+
seq_len = seq_len,
|
217 |
+
depth = n_layer,
|
218 |
+
heads = n_head,
|
219 |
+
dim_head = d_head,
|
220 |
+
attn_dropout = attn_pdrop,
|
221 |
+
resid_dropout = resid_pdrop,
|
222 |
+
embd_dropout = embd_pdrop,
|
223 |
+
ff_dropout = ff_dropout,
|
224 |
+
attn_types = attn_types,
|
225 |
+
image_size = image_size,
|
226 |
+
)
|
227 |
+
|
228 |
+
self.to_logits = nn.Sequential(
|
229 |
+
nn.LayerNorm(n_embd),
|
230 |
+
nn.Linear(n_embd, self.total_tokens),
|
231 |
+
)
|
232 |
+
|
233 |
+
seq_range = torch.arange(seq_len)
|
234 |
+
logits_range = torch.arange(total_tokens)
|
235 |
+
|
236 |
+
seq_range = rearrange(seq_range, 'n -> () n ()')
|
237 |
+
logits_range = rearrange(logits_range, 'd -> () () d')
|
238 |
+
|
239 |
+
logits_mask = (
|
240 |
+
((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
|
241 |
+
((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
|
242 |
+
)
|
243 |
+
|
244 |
+
self.register_buffer('logits_mask', logits_mask, persistent=False)
|
245 |
+
self.loss_img_weight = loss_img_weight
|
246 |
+
self.loss_txt_weight = loss_txt_weight
|
247 |
+
|
248 |
+
self.apply(self._init_weights)
|
249 |
+
|
250 |
+
def _init_weights(self, module):
|
251 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
252 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
253 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
254 |
+
module.bias.data.zero_()
|
255 |
+
elif isinstance(module, nn.LayerNorm):
|
256 |
+
module.bias.data.zero_()
|
257 |
+
module.weight.data.fill_(1.0)
|
258 |
+
|
259 |
+
@torch.no_grad()
|
260 |
+
@eval_decorator
|
261 |
+
def generate_images(
|
262 |
+
self,
|
263 |
+
text,
|
264 |
+
decoder,
|
265 |
+
*,
|
266 |
+
clip = None,
|
267 |
+
filter_thres = 0.5,
|
268 |
+
temperature = 1.,
|
269 |
+
img = None,
|
270 |
+
num_init_img_tokens = None
|
271 |
+
):
|
272 |
+
text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
|
273 |
+
total_len = text_seq_len + image_seq_len
|
274 |
+
|
275 |
+
batch = text.shape[0]
|
276 |
+
text = text[:, :text_seq_len] # make sure text is within bounds
|
277 |
+
out = text
|
278 |
+
|
279 |
+
if exists(img):
|
280 |
+
assert img.shape[1] == image_seq_len, f'input image must have the correct image size {image_seq_len}'
|
281 |
+
|
282 |
+
num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len)) # OpenAI used 14 * 32 initial tokens to prime
|
283 |
+
assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'
|
284 |
+
|
285 |
+
trunc_img = img[:, :num_img_tokens]
|
286 |
+
out = torch.cat((out, trunc_img), dim = -1)
|
287 |
+
|
288 |
+
for cur_len in range(out.shape[1], total_len):
|
289 |
+
is_image = cur_len >= text_seq_len
|
290 |
+
|
291 |
+
text, image = out[:, :text_seq_len], out[:, text_seq_len:]
|
292 |
+
|
293 |
+
logits = self(text, image)[:, -1, :]
|
294 |
+
|
295 |
+
filtered_logits = top_k(logits, thres = filter_thres)
|
296 |
+
probs = F.softmax(filtered_logits / temperature, dim = -1)
|
297 |
+
sample = torch.multinomial(probs, 1)
|
298 |
+
|
299 |
+
sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
|
300 |
+
out = torch.cat((out, sample), dim=-1)
|
301 |
+
|
302 |
+
text_seq = out[:, :text_seq_len]
|
303 |
+
img_seq = out[:, -image_seq_len:]
|
304 |
+
|
305 |
+
scores = None
|
306 |
+
if exists(clip):
|
307 |
+
scores = clip(text_seq, img_seq, return_loss = False)
|
308 |
+
|
309 |
+
img_seq = repeat(img_seq, 'b p -> b p c', c=3)
|
310 |
+
decoder = repeat(decoder, 'p c -> b p c', b=batch)
|
311 |
+
images = torch.gather(decoder, 1, img_seq)
|
312 |
+
images = rearrange(images, 'b (h w) c-> b c h w', h=24, w =24)
|
313 |
+
images = images.float()
|
314 |
+
|
315 |
+
return images, scores
|
316 |
+
|
317 |
+
def forward(
|
318 |
+
self,
|
319 |
+
text,
|
320 |
+
image = None,
|
321 |
+
return_loss = False
|
322 |
+
):
|
323 |
+
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
|
324 |
+
device, total_seq_len = text.device, self.total_seq_len
|
325 |
+
|
326 |
+
text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
|
327 |
+
text = torch.where(text == 0, text_range, text)
|
328 |
+
|
329 |
+
text = F.pad(text, (1, 0), value = 0) # add <bos>
|
330 |
+
|
331 |
+
tokens = self.text_emb(text)
|
332 |
+
tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))
|
333 |
+
|
334 |
+
seq_len = tokens.shape[1]
|
335 |
+
|
336 |
+
image_len = image.shape[1]
|
337 |
+
image_emb = self.image_emb(image)
|
338 |
+
# image_emb += self.image_pos_emb(torch.arange(image_len, device = device))
|
339 |
+
image_emb += self.image_pos_emb[:, :image_len, :]
|
340 |
+
|
341 |
+
# image_emb += self.image_pos_emb(image_emb)
|
342 |
+
|
343 |
+
tokens = torch.cat((tokens, image_emb), dim = 1)
|
344 |
+
|
345 |
+
seq_len += image_len
|
346 |
+
|
347 |
+
# when training, if the length exceeds the total text + image length
|
348 |
+
# remove the last token, since it needs not to be trained
|
349 |
+
|
350 |
+
if tokens.shape[1] > total_seq_len:
|
351 |
+
seq_len -= 1
|
352 |
+
tokens = tokens[:, :-1]
|
353 |
+
|
354 |
+
out = self.transformer(tokens)
|
355 |
+
logits = self.to_logits(out)
|
356 |
+
|
357 |
+
# mask logits to make sure text predicts text (except last token), and image predicts image
|
358 |
+
|
359 |
+
logits_mask = self.logits_mask[:, :seq_len]
|
360 |
+
max_neg_value = -torch.finfo(logits.dtype).max
|
361 |
+
logits.masked_fill_(logits_mask, max_neg_value)
|
362 |
+
|
363 |
+
if not return_loss:
|
364 |
+
return logits
|
365 |
+
|
366 |
+
assert exists(image), 'when training, image must be supplied'
|
367 |
+
|
368 |
+
offsetted_image = image + self.num_text_tokens
|
369 |
+
labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)
|
370 |
+
|
371 |
+
logits = rearrange(logits, 'b n c -> b c n')
|
372 |
+
|
373 |
+
loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
|
374 |
+
loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])
|
375 |
+
|
376 |
+
loss = (self.loss_txt_weight * loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + self.loss_txt_weight)
|
377 |
+
return loss, loss_text, loss_img
|
text2punks/tokenizer.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import html
|
3 |
+
import ftfy
|
4 |
+
import regex as re
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from functools import lru_cache
|
10 |
+
|
11 |
+
import youtokentome as yttm
|
12 |
+
from tokenizers import Tokenizer
|
13 |
+
from tokenizers.processors import ByteLevel
|
14 |
+
|
15 |
+
|
16 |
+
# OpenAI simple tokenizer
|
17 |
+
|
18 |
+
@lru_cache()
|
19 |
+
def default_bpe(bpe_path = "data/bpe_simple_vocab_16e6.txt"):
|
20 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), bpe_path)
|
21 |
+
|
22 |
+
@lru_cache()
|
23 |
+
def bytes_to_unicode():
|
24 |
+
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
25 |
+
cs = bs[:]
|
26 |
+
n = 0
|
27 |
+
for b in range(2 ** 8):
|
28 |
+
if b not in bs:
|
29 |
+
bs.append(b)
|
30 |
+
cs.append(2 ** 8 + n)
|
31 |
+
n += 1
|
32 |
+
cs = [chr(n) for n in cs]
|
33 |
+
return dict(zip(bs, cs))
|
34 |
+
|
35 |
+
def get_pairs(word):
|
36 |
+
pairs = set()
|
37 |
+
prev_char = word[0]
|
38 |
+
for char in word[1:]:
|
39 |
+
pairs.add((prev_char, char))
|
40 |
+
prev_char = char
|
41 |
+
return pairs
|
42 |
+
|
43 |
+
def basic_clean(text):
|
44 |
+
text = ftfy.fix_text(text)
|
45 |
+
text = html.unescape(html.unescape(text))
|
46 |
+
return text.strip()
|
47 |
+
|
48 |
+
def whitespace_clean(text):
|
49 |
+
text = re.sub(r'\s+', ' ', text)
|
50 |
+
text = text.strip()
|
51 |
+
return text
|
52 |
+
|
53 |
+
|
54 |
+
class SimpleTokenizer(object):
|
55 |
+
def __init__(self, bpe_path = default_bpe()):
|
56 |
+
self.byte_encoder = bytes_to_unicode()
|
57 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
58 |
+
merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
|
59 |
+
merges = merges[1:49152 - 256 - 2 + 1]
|
60 |
+
merges = [tuple(merge.split()) for merge in merges]
|
61 |
+
vocab = list(bytes_to_unicode().values())
|
62 |
+
vocab = vocab + [v + '</w>' for v in vocab]
|
63 |
+
for merge in merges:
|
64 |
+
vocab.append(''.join(merge))
|
65 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
66 |
+
|
67 |
+
self.vocab_size = 49408
|
68 |
+
|
69 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
70 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
71 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
72 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
73 |
+
self.pat = re.compile(
|
74 |
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
75 |
+
re.IGNORECASE)
|
76 |
+
|
77 |
+
def bpe(self, token):
|
78 |
+
if token in self.cache:
|
79 |
+
return self.cache[token]
|
80 |
+
word = tuple(token[:-1]) + (token[-1] + '</w>',)
|
81 |
+
pairs = get_pairs(word)
|
82 |
+
|
83 |
+
if not pairs:
|
84 |
+
return token + '</w>'
|
85 |
+
|
86 |
+
while True:
|
87 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
88 |
+
if bigram not in self.bpe_ranks:
|
89 |
+
break
|
90 |
+
first, second = bigram
|
91 |
+
new_word = []
|
92 |
+
i = 0
|
93 |
+
while i < len(word):
|
94 |
+
try:
|
95 |
+
j = word.index(first, i)
|
96 |
+
new_word.extend(word[i:j])
|
97 |
+
i = j
|
98 |
+
except:
|
99 |
+
new_word.extend(word[i:])
|
100 |
+
break
|
101 |
+
|
102 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
103 |
+
new_word.append(first + second)
|
104 |
+
i += 2
|
105 |
+
else:
|
106 |
+
new_word.append(word[i])
|
107 |
+
i += 1
|
108 |
+
new_word = tuple(new_word)
|
109 |
+
word = new_word
|
110 |
+
if len(word) == 1:
|
111 |
+
break
|
112 |
+
else:
|
113 |
+
pairs = get_pairs(word)
|
114 |
+
word = ' '.join(word)
|
115 |
+
self.cache[token] = word
|
116 |
+
return word
|
117 |
+
|
118 |
+
def encode(self, text):
|
119 |
+
bpe_tokens = []
|
120 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
121 |
+
for token in re.findall(self.pat, text):
|
122 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
123 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
124 |
+
return bpe_tokens
|
125 |
+
|
126 |
+
def decode(self, tokens, remove_start_end = True):
|
127 |
+
if torch.is_tensor(tokens):
|
128 |
+
tokens = tokens.tolist()
|
129 |
+
|
130 |
+
if remove_start_end:
|
131 |
+
tokens = [token for token in tokens if token not in (49406, 40407, 0)]
|
132 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
133 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
134 |
+
return text
|
135 |
+
|
136 |
+
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
137 |
+
if isinstance(texts, str):
|
138 |
+
texts = [texts]
|
139 |
+
|
140 |
+
all_tokens = [self.encode(text) for text in texts]
|
141 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
142 |
+
|
143 |
+
for i, tokens in enumerate(all_tokens):
|
144 |
+
if len(tokens) > context_length:
|
145 |
+
if truncate_text:
|
146 |
+
tokens = tokens[:context_length]
|
147 |
+
else:
|
148 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
149 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
150 |
+
|
151 |
+
return result
|
152 |
+
|
153 |
+
# txt_tokenizer = SimpleTokenizer()
|
154 |
+
|
155 |
+
# huggingface tokenizer
|
156 |
+
|
157 |
+
class HugTokenizer:
|
158 |
+
def __init__(self, bpe_path):
|
159 |
+
bpe_path = Path(default_bpe(bpe_path = bpe_path))
|
160 |
+
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
161 |
+
tokenizer = Tokenizer.from_file(str(bpe_path))
|
162 |
+
tokenizer.post_processor = ByteLevel(trim_offsets = True)
|
163 |
+
self.tokenizer = tokenizer
|
164 |
+
self.vocab_size = tokenizer.get_vocab_size()
|
165 |
+
|
166 |
+
def decode(self, tokens):
|
167 |
+
if torch.is_tensor(tokens):
|
168 |
+
tokens = tokens.tolist()
|
169 |
+
|
170 |
+
tokens = [token for token in tokens if token not in (0,)]
|
171 |
+
return self.tokenizer.decode(tokens, skip_special_tokens = True)
|
172 |
+
|
173 |
+
def encode(self, text):
|
174 |
+
return self.tokenizer.encode(text).ids
|
175 |
+
|
176 |
+
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
177 |
+
if isinstance(texts, str):
|
178 |
+
texts = [texts]
|
179 |
+
|
180 |
+
all_tokens = [self.encode(text) for text in texts]
|
181 |
+
|
182 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
183 |
+
for i, tokens in enumerate(all_tokens):
|
184 |
+
if len(tokens) > context_length:
|
185 |
+
if truncate_text:
|
186 |
+
tokens = tokens[:context_length]
|
187 |
+
else:
|
188 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
189 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
190 |
+
|
191 |
+
return result
|
192 |
+
|
193 |
+
txt_tokenizer = HugTokenizer(bpe_path = "data/byte-level-bpe_4k.tokenizer.json")
|
194 |
+
|
195 |
+
# yttm tokenizer
|
196 |
+
|
197 |
+
class YttmTokenizer:
|
198 |
+
def __init__(self, bpe_path = None):
|
199 |
+
bpe_path = Path(default_bpe(bpe_path = bpe_path))
|
200 |
+
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
201 |
+
|
202 |
+
tokenizer = yttm.BPE(model = str(bpe_path))
|
203 |
+
self.tokenizer = tokenizer
|
204 |
+
self.vocab_size = tokenizer.vocab_size()
|
205 |
+
|
206 |
+
def decode(self, tokens):
|
207 |
+
if torch.is_tensor(tokens):
|
208 |
+
tokens = tokens.tolist()
|
209 |
+
|
210 |
+
return self.tokenizer.decode(tokens, ignore_ids = [0])
|
211 |
+
|
212 |
+
def encode(self, texts):
|
213 |
+
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
214 |
+
return list(map(torch.tensor, encoded))
|
215 |
+
|
216 |
+
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
217 |
+
if isinstance(texts, str):
|
218 |
+
texts = [texts]
|
219 |
+
|
220 |
+
all_tokens = self.encode(texts)
|
221 |
+
|
222 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
223 |
+
for i, tokens in enumerate(all_tokens):
|
224 |
+
if len(tokens) > context_length:
|
225 |
+
if truncate_text:
|
226 |
+
tokens = tokens[:context_length]
|
227 |
+
else:
|
228 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
229 |
+
result[i, :len(tokens)] = tokens.detach().clone()
|
230 |
+
|
231 |
+
return result
|
232 |
+
|
233 |
+
# txt_tokenizer = YttmTokenizer(bpe_path = "data/byte-level-bpe.tokenizer.json")
|
text2punks/transformer.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from itertools import islice, cycle
|
3 |
+
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from text2punks.attention import Attention, SparseAxialCausalAttention
|
7 |
+
|
8 |
+
# helpers
|
9 |
+
|
10 |
+
def exists(val):
|
11 |
+
return val is not None
|
12 |
+
|
13 |
+
def default(val, d):
|
14 |
+
return val if exists(val) else d
|
15 |
+
|
16 |
+
def cast_tuple(val, depth = 1):
|
17 |
+
if isinstance(val, list):
|
18 |
+
val = tuple(val)
|
19 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
20 |
+
|
21 |
+
# classes
|
22 |
+
|
23 |
+
class SequentialSequence(nn.Module):
|
24 |
+
def __init__(self, layers):
|
25 |
+
super().__init__()
|
26 |
+
self.layers = layers
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
for (f, g) in list(self.layers):
|
30 |
+
x = x + f(x)
|
31 |
+
x = x + g(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
class PreNorm(nn.Module):
|
35 |
+
def __init__(self, dim, fn):
|
36 |
+
super().__init__()
|
37 |
+
self.norm = nn.LayerNorm(dim)
|
38 |
+
self.fn = fn
|
39 |
+
|
40 |
+
def forward(self, x, **kwargs):
|
41 |
+
return self.fn(self.norm(x), **kwargs)
|
42 |
+
|
43 |
+
class FeedForward(nn.Module):
|
44 |
+
def __init__(self, dim, dropout = 0.):
|
45 |
+
super().__init__()
|
46 |
+
self.net = nn.Sequential(
|
47 |
+
nn.Linear(dim, dim * 4),
|
48 |
+
nn.GELU(),
|
49 |
+
nn.Dropout(dropout),
|
50 |
+
nn.Linear(dim * 4, dim)
|
51 |
+
)
|
52 |
+
# the order of dropout nn.Linear(4 * n_embd, n_embd) vs nn.Dropout(resid_pdrop)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
return self.net(x)
|
56 |
+
|
57 |
+
|
58 |
+
class Transformer(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
*,
|
62 |
+
dim,
|
63 |
+
depth,
|
64 |
+
seq_len,
|
65 |
+
causal = True,
|
66 |
+
heads = 8,
|
67 |
+
dim_head = 64,
|
68 |
+
attn_dropout = 0.,
|
69 |
+
resid_dropout = 0.,
|
70 |
+
embd_dropout = 0.,
|
71 |
+
ff_dropout = 0.,
|
72 |
+
image_size = 24,
|
73 |
+
attn_types = None,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
layers = nn.ModuleList([])
|
77 |
+
|
78 |
+
attn_types = default(attn_types, ('full',))
|
79 |
+
attn_types = cast_tuple(attn_types)
|
80 |
+
attn_type_layer = islice(cycle(attn_types), depth)
|
81 |
+
|
82 |
+
for attn_type in attn_type_layer:
|
83 |
+
if attn_type == 'full':
|
84 |
+
attn_class = partial(Attention, causal = causal)
|
85 |
+
elif attn_type == 'axial_row':
|
86 |
+
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
|
87 |
+
elif attn_type == 'axial_col':
|
88 |
+
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
|
89 |
+
else:
|
90 |
+
raise ValueError(f'attention type "{attn_type}" is not valid')
|
91 |
+
|
92 |
+
attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)
|
93 |
+
|
94 |
+
layers.append(nn.ModuleList([
|
95 |
+
PreNorm(dim, attn),
|
96 |
+
PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
|
97 |
+
]))
|
98 |
+
|
99 |
+
# full attention in the last layer
|
100 |
+
|
101 |
+
attn_class = partial(Attention, causal = causal)
|
102 |
+
attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)
|
103 |
+
|
104 |
+
layers.append(nn.ModuleList([
|
105 |
+
PreNorm(dim, attn),
|
106 |
+
PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
|
107 |
+
]))
|
108 |
+
|
109 |
+
self.layers = SequentialSequence(layers)
|
110 |
+
self.embd_drop = nn.Dropout(embd_dropout)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
x = self.embd_drop(x)
|
114 |
+
return self.layers(x)
|
115 |
+
|
text2punks/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# os
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
# torch
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
from einops import repeat
|
10 |
+
|
11 |
+
# Text2Punks and Tokenizer
|
12 |
+
|
13 |
+
from text2punks.text2punk import Text2Punks, CLIP
|
14 |
+
from text2punks.tokenizer import txt_tokenizer
|
15 |
+
|
16 |
+
# select device
|
17 |
+
|
18 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
19 |
+
|
20 |
+
# load decoder
|
21 |
+
|
22 |
+
codebook = torch.load('./text2punks/data/codebook.pt')
|
23 |
+
|
24 |
+
# helper fns
|
25 |
+
|
26 |
+
def exists(val):
|
27 |
+
return val is not None
|
28 |
+
|
29 |
+
|
30 |
+
def to_pil_image(image_tensor):
|
31 |
+
return F.to_pil_image(image_tensor)
|
32 |
+
|
33 |
+
|
34 |
+
def model_loader(text2punk_path, clip_path):
|
35 |
+
# load pre-trained TEXT2PUNKS model
|
36 |
+
|
37 |
+
text2punk_path = Path(text2punk_path)
|
38 |
+
assert text2punk_path.exists(), 'trained Text2Punks must exist'
|
39 |
+
|
40 |
+
load_obj = torch.load(str(text2punk_path), map_location=torch.device(device))
|
41 |
+
text2punks_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')
|
42 |
+
|
43 |
+
text2punk = Text2Punks(**text2punks_params).to(device)
|
44 |
+
text2punk.load_state_dict(weights)
|
45 |
+
|
46 |
+
# load pre-trained CLIP model
|
47 |
+
|
48 |
+
clip_path = Path(clip_path)
|
49 |
+
assert clip_path.exists(), 'trained CLIP must exist'
|
50 |
+
|
51 |
+
load_obj = torch.load(str(clip_path), map_location=torch.device(device))
|
52 |
+
clip_params, weights = load_obj.pop('hparams'), load_obj.pop('weights')
|
53 |
+
|
54 |
+
clip = CLIP(**clip_params).to(device)
|
55 |
+
clip.load_state_dict(weights)
|
56 |
+
|
57 |
+
return text2punk, clip
|
58 |
+
|
59 |
+
|
60 |
+
def generate_image(prompt_text, top_k, temperature, num_images, batch_size, top_prediction, text2punk_model, clip_model, codebook=codebook):
|
61 |
+
text = txt_tokenizer.tokenize(prompt_text, text2punk_model.text_seq_len, truncate_text=True).to(device)
|
62 |
+
|
63 |
+
text = repeat(text, '() n -> b n', b = num_images)
|
64 |
+
|
65 |
+
img_outputs = []
|
66 |
+
score_outputs = []
|
67 |
+
|
68 |
+
for text_chunk in text.split(batch_size):
|
69 |
+
images, scores = text2punk_model.generate_images(text_chunk, codebook.to(device), clip = clip_model, filter_thres = top_k, temperature = temperature)
|
70 |
+
img_outputs.append(images)
|
71 |
+
score_outputs.append(scores)
|
72 |
+
|
73 |
+
img_outputs = torch.cat(img_outputs)
|
74 |
+
score_outputs = torch.cat(score_outputs)
|
75 |
+
|
76 |
+
similarity = score_outputs.softmax(dim=-1)
|
77 |
+
values, indices = similarity.topk(top_prediction)
|
78 |
+
|
79 |
+
img_outputs = img_outputs[indices]
|
80 |
+
score_outputs = score_outputs[indices]
|
81 |
+
|
82 |
+
return img_outputs, score_outputs
|