NoelVouitsis commited on
Commit
cc9204b
1 Parent(s): 1529364
README.md CHANGED
@@ -2,9 +2,9 @@
2
  title: TR0N
3
  emoji: 🐨
4
  colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
2
  title: TR0N
3
  emoji: 🐨
4
  colorFrom: gray
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.24.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.optim import Adam
7
+ from torchvision.transforms import transforms as T
8
+ import clip
9
+ from tr0n.config import parse_args
10
+ from tr0n.modules.models.model_stylegan import Model
11
+ from tr0n.modules.models.loss import AugCosineSimLatent
12
+ from tr0n.modules.optimizers.sgld import SGLD
13
+ from bad_words import bad_words
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model_modes = {
17
+ "text": {
18
+ "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth",
19
+ },
20
+ "image": {
21
+ "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth",
22
+ }
23
+ }
24
+
25
+ os.environ['TOKENIZERS_PARALLELISM'] = "false"
26
+
27
+
28
+ # set config params
29
+ config = parse_args(is_demo=True)
30
+ config_vars = vars(config)
31
+ config_vars["stylegan_gen"] = "sg2-ffhq-1024"
32
+ config_vars["with_gmm"] = True
33
+ config_vars["num_mixtures"] = 10
34
+
35
+
36
+ model = Model(config, device, None)
37
+ model.to(device)
38
+ model.eval()
39
+ for p in model.translator.parameters():
40
+ p.requires_grad = False
41
+ loss = AugCosineSimLatent()
42
+
43
+
44
+ transforms_image = T.Compose([
45
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
46
+ T.CenterCrop(224),
47
+ T.ToTensor(),
48
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
49
+ ])
50
+
51
+
52
+ checkpoint_text = torch.hub.load_state_dict_from_url(model_modes["text"]["checkpoint"], map_location="cpu")
53
+ translator_state_dict_text = checkpoint_text['translator_state_dict']
54
+ checkpoint_image = torch.hub.load_state_dict_from_url(model_modes["image"]["checkpoint"], map_location="cpu")
55
+ translator_state_dict_image = checkpoint_image['translator_state_dict']
56
+
57
+ # default
58
+ model.translator.load_state_dict(translator_state_dict_text)
59
+
60
+
61
+ css = """
62
+ a {
63
+ display: inline-block;
64
+ color: black !important;
65
+ text-decoration: none !important;
66
+ }
67
+ #image-gen {
68
+ height: 256px;
69
+ width: 256px;
70
+ margin-left: auto;
71
+ margin-right: auto;
72
+ }
73
+ """
74
+
75
+
76
+ def _slerp(val, low, high):
77
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
78
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
79
+ omega = torch.acos((low_norm*high_norm).sum(1))
80
+ so = torch.sin(omega)
81
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
82
+ return res
83
+
84
+
85
+ def model_mode_text_select():
86
+ model.translator.load_state_dict(translator_state_dict_text)
87
+
88
+
89
+ def model_mode_image_select():
90
+ model.translator.load_state_dict(translator_state_dict_image)
91
+
92
+
93
+ def text_to_face_generate(text):
94
+ if text == "":
95
+ raise gr.Error("You need to provide to provide a prompt.")
96
+
97
+ for word in bad_words:
98
+ if re.search(rf"\b{word}\b", text):
99
+ raise gr.Error("Unsafe content found. Please try again with a different prompt.")
100
+
101
+ text_tok = clip.tokenize([text], truncate=True).to(device)
102
+
103
+ # initialize optimization from the translator's output
104
+ with torch.no_grad():
105
+ target_clip_latent, w_mixture_logits, w_means = model(x=text_tok, x_type='text', return_after_translator=True, no_sample=True)
106
+ pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1]) # 1 x num_mixtures x w_dim
107
+ w = w_means # 1 x num_mixtures x w_dim
108
+
109
+ w.requires_grad = True
110
+ pi.requires_grad = True
111
+
112
+ optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
113
+ optimizer_pi = Adam((pi,), lr=5e-3)
114
+
115
+ # optimization
116
+ for _ in range(100):
117
+ soft_pi = F.softmax(pi, dim=1)
118
+ w_prime = soft_pi * w
119
+ w_prime = w_prime.sum(dim=1)
120
+
121
+ _, _, pred_clip_latent, _, _ = model(x=w_prime, x_type='gan_latent', times_augment_pred_image=50)
122
+
123
+ l = loss(target_clip_latent, pred_clip_latent)
124
+ l.backward()
125
+ torch.nn.utils.clip_grad_norm_((w,), 1.)
126
+ torch.nn.utils.clip_grad_norm_((pi,), 1.)
127
+ optimizer_w.step()
128
+ optimizer_pi.step()
129
+ optimizer_w.zero_grad()
130
+ optimizer_pi.zero_grad()
131
+
132
+ # generate final image
133
+ with torch.no_grad():
134
+ soft_pi = F.softmax(pi, dim=1)
135
+ w_prime = soft_pi * w
136
+ w_prime = w_prime.sum(dim=1)
137
+
138
+ _, _, _, _, pred_image_raw = model(x=w_prime, x_type='gan_latent')
139
+
140
+ pred_image = ((pred_image_raw[0]+1.)/2.).cpu()
141
+ return T.ToPILImage()(pred_image)
142
+
143
+
144
+ def face_to_face_interpolate(image1, image2, interp_lambda=0.5):
145
+ if image1 is None or image2 is None:
146
+ raise gr.Error("You need to provide two images as input.")
147
+
148
+ image1_pt = transforms_image(image1).to(device)
149
+ image2_pt = transforms_image(image2).to(device)
150
+
151
+ # initialize optimization from the translator's output
152
+ with torch.no_grad():
153
+ images_pt = torch.stack([image1_pt, image2_pt])
154
+ target_clip_latents = model.clip.encode_image(images_pt).detach().float()
155
+ target_clip_latent = _slerp(interp_lambda, target_clip_latents[0].unsqueeze(0), target_clip_latents[1].unsqueeze(0))
156
+ _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True)
157
+
158
+ w.requires_grad = True
159
+
160
+ optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
161
+
162
+ # optimization
163
+ for _ in range(100):
164
+ _, _, pred_clip_latent, _, _ = model(x=w, x_type='gan_latent', times_augment_pred_image=50)
165
+
166
+ l = loss(target_clip_latent, pred_clip_latent)
167
+ l.backward()
168
+ torch.nn.utils.clip_grad_norm_((w,), 1.)
169
+ optimizer_w.step()
170
+ optimizer_w.zero_grad()
171
+
172
+ # generate final image
173
+ with torch.no_grad():
174
+ _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent')
175
+
176
+ pred_image = ((pred_image_raw[0]+1.)/2.).cpu()
177
+ return T.ToPILImage()(pred_image)
178
+
179
+
180
+ examples_text = [
181
+ "Muhammad Ali",
182
+ "Tinker Bell",
183
+ "A man with glasses, long black hair with sideburns and a goatee",
184
+ "A child with blue eyes and straight brown hair in the sunshine",
185
+ "A hairdresser",
186
+ "A young boy with glasses and an angry face",
187
+ "Denzel Washington",
188
+ "A portrait of Angela Merkel",
189
+ "President Emmanuel Macron",
190
+ "Prime Minister Shinzo Abe"
191
+ ]
192
+
193
+ examples_image = [
194
+ ["./examples/example_1_1.jpg", "./examples/example_1_2.jpg"],
195
+ ["./examples/example_2_1.jpg", "./examples/example_2_2.jpg"],
196
+ ["./examples/example_3_1.jpg", "./examples/example_3_2.jpg"],
197
+ ["./examples/example_4_1.jpg", "./examples/example_4_2.jpg"],
198
+ ]
199
+
200
+
201
+ with gr.Blocks(css=css) as demo:
202
+ gr.Markdown("<h1><center>TR0N Face Generation Demo</center></h1>")
203
+ gr.Markdown("<h3><center><a href='https://layer6.ai/'>by Layer 6 AI</a></center></h3>")
204
+ gr.Markdown("""<p align='middle'>
205
+ <a href='https://arxiv.org/abs/2304.13742'><img src='https://img.shields.io/badge/arXiv-2304.13742-b31b1b.svg' /></a>
206
+ <a href='https://github.com/layer6ai-labs/tr0n'><img src='https://badgen.net/badge/icon/github?icon=github&label' /></a>
207
+ </p>""")
208
+ gr.Markdown("We introduce TR0N, a simple and efficient method to add any type of conditioning to pre-trained generative models. For this demo, we add two types of conditioning to a StyleGAN2 model pre-trained on images of human faces. First, we add text-conditioning to turn StyleGAN2 into a text-to-face model. Second, we add image semantic conditioning to StyleGAN2 to enable face-to-face interpolation. For more details and results on many other generative models, please refer to our paper linked above.")
209
+
210
+ with gr.Tab("Text-to-face generation") as text_to_face_generation_demo:
211
+ text_to_face_generation_input = gr.Textbox(label="Enter your prompt", placeholder="e.g. A man with a beard and glasses", max_lines=1)
212
+ text_to_face_generation_button = gr.Button("Generate")
213
+ text_to_face_generation_output = gr.Image(label="Generated image", elem_id="image-gen")
214
+ text_to_face_generation_examples = gr.Examples(examples=examples_text, fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
215
+
216
+ with gr.Tab("Face-to-face interpolation") as face_to_face_interpolation_demo:
217
+ gr.Markdown("We note that interpolations are not expected to recover the given images, even when the coefficient is 0 or 1.")
218
+ with gr.Row():
219
+ face_to_face_interpolation_input1 = gr.Image(label="Image 1", type="pil")
220
+ face_to_face_interpolation_input2 = gr.Image(label="Image 2", type="pil")
221
+ face_to_face_interpolation_lambda = gr.Slider(label="Interpolation coefficient", minimum=0, maximum=1, value=0.5, step=0.01)
222
+ face_to_face_interpolation_button = gr.Button("Interpolate")
223
+ face_to_face_interpolation_output = gr.Image(label="Interpolated image", elem_id="image-gen")
224
+ face_to_face_interpolation_examples = gr.Examples(examples=examples_image, fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
225
+
226
+ text_to_face_generation_demo.select(fn=model_mode_text_select)
227
+ text_to_face_generation_input.submit(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
228
+ text_to_face_generation_button.click(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
229
+
230
+ face_to_face_interpolation_demo.select(fn=model_mode_image_select)
231
+ face_to_face_interpolation_button.click(fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
232
+
233
+
234
+ demo.queue()
235
+ demo.launch()
bad_words.py ADDED
@@ -0,0 +1 @@
 
 
1
+ bad_words = ["4r5e", "5h1t", "5hit", "a55", "anal", "anus", "ar5e", "arrse", "arse", "ass", "ass-fucker", "asses", "assfucker", "assfukka", "asshole", "assholes", "asswhole", "a_s_s", "b!tch", "b00bs", "b17ch", "b1tch", "ballbag", "balls", "ballsack", "bastard", "beastial", "beastiality", "bellend", "bestial", "bestiality", "bi+ch", "biatch", "bitch", "bitcher", "bitchers", "bitches", "bitchin", "bitching", "bloody", "blow job", "blowjob", "blowjobs", "boiolas", "bollock", "bollok", "boner", "boob", "boobs", "booobs", "boooobs", "booooobs", "booooooobs", "breasts", "buceta", "bugger", "bum", "bunny fucker", "butt", "butthole", "buttmuch", "buttplug", "c0ck", "c0cksucker", "carpet muncher", "cawk", "chink", "cipa", "cl1t", "clit", "clitoris", "clits", "cnut", "cock", "cock-sucker", "cockface", "cockhead", "cockmunch", "cockmuncher", "cocks", "cocksuck", "cocksucked", "cocksucker", "cocksucking", "cocksucks", "cocksuka", "cocksukka", "cok", "cokmuncher", "coksucka", "coon", "cox", "crap", "cum", "cummer", "cumming", "cums", "cumshot", "cunilingus", "cunillingus", "cunnilingus", "cunt", "cuntlick", "cuntlicker", "cuntlicking", "cunts", "cyalis", "cyberfuc", "cyberfuck", "cyberfucked", "cyberfucker", "cyberfuckers", "cyberfucking", "d1ck", "damn", "dick", "dickhead", "dildo", "dildos", "dink", "dinks", "dirsa", "dlck", "dog-fucker", "doggin", "dogging", "donkeyribber", "doosh", "duche", "dyke", "ejaculate", "ejaculated", "ejaculates", "ejaculating", "ejaculatings", "ejaculation", "ejakulate", "f u c k", "f u c k e r", "f4nny", "fag", "fagging", "faggitt", "faggot", "faggs", "fagot", "fagots", "fags", "fanny", "fannyflaps", "fannyfucker", "fanyy", "fatass", "fcuk", "fcuker", "fcuking", "feck", "fecker", "felching", "fellate", "fellatio", "fingerfuck", "fingerfucked", "fingerfucker", "fingerfuckers", "fingerfucking", "fingerfucks", "fistfuck", "fistfucked", "fistfucker", "fistfuckers", "fistfucking", "fistfuckings", "fistfucks", "flange", "fook", "fooker", "fuck", "fucka", "fucked", "fucker", "fuckers", "fuckhead", "fuckheads", "fuckin", "fucking", "fuckings", "fuckingshitmotherfucker", "fuckme", "fucks", "fuckwhit", "fuckwit", "fudge packer", "fudgepacker", "fuk", "fuker", "fukker", "fukkin", "fuks", "fukwhit", "fukwit", "fux", "fux0r", "f_u_c_k", "gangbang", "gangbanged", "gangbangs", "gaylord", "gaysex", "goatse", "God", "god-dam", "god-damned", "goddamn", "goddamned", "hardcoresex", "hell", "heshe", "hoar", "hoare", "hoer", "homo", "hore", "horniest", "horny", "hotsex", "jack-off", "jackoff", "jap", "jerk-off", "jism", "jiz", "jizm", "jizz", "kawk", "knob", "knobead", "knobed", "knobend", "knobhead", "knobjocky", "knobjokey", "kock", "kondum", "kondums", "kum", "kummer", "kumming", "kums", "kunilingus", "l3i+ch", "l3itch", "labia", "lust", "lusting", "m0f0", "m0fo", "m45terbate", "ma5terb8", "ma5terbate", "masochist", "master-bate", "masterb8", "masterbat*", "masterbat3", "masterbate", "masterbation", "masterbations", "masturbate", "mo-fo", "mof0", "mofo", "mothafuck", "mothafucka", "mothafuckas", "mothafuckaz", "mothafucked", "mothafucker", "mothafuckers", "mothafuckin", "mothafucking", "mothafuckings", "mothafucks", "mother fucker", "motherfuck", "motherfucked", "motherfucker", "motherfuckers", "motherfuckin", "motherfucking", "motherfuckings", "motherfuckka", "motherfucks", "muff", "mutha", "muthafecker", "muthafuckker", "muther", "mutherfucker", "n1gga", "n1gger", "nazi", "nigg3r", "nigg4h", "nigga", "niggah", "niggas", "niggaz", "nigger", "niggers", "nob", "nob jokey", "nobhead", "nobjocky", "nobjokey", "numbnuts", "nutsack", "orgasim", "orgasims", "orgasm", "orgasms", "p0rn", "pawn", "pecker", "penis", "penisfucker", "phonesex", "phuck", "phuk", "phuked", "phuking", "phukked", "phukking", "phuks", "phuq", "pigfucker", "pimpis", "piss", "pissed", "pisser", "pissers", "pisses", "pissflaps", "pissin", "pissing", "pissoff", "poop", "porn", "porno", "pornography", "pornos", "prick", "pricks", "pron", "pube", "pusse", "pussi", "pussies", "pussy", "pussys", "rectum", "retard", "rimjaw", "rimming", "s hit", "s.o.b.", "sadist", "schlong", "screwing", "scroat", "scrote", "scrotum", "semen", "sex", "sh!+", "sh!t", "sh1t", "shag", "shagger", "shaggin", "shagging", "shemale", "shi+", "shit", "shitdick", "shite", "shited", "shitey", "shitfuck", "shitfull", "shithead", "shiting", "shitings", "shits", "shitted", "shitter", "shitters", "shitting", "shittings", "shitty", "skank", "slut", "sluts", "smegma", "smut", "snatch", "son-of-a-bitch", "spac", "spunk", "s_h_i_t", "t1tt1e5", "t1tties", "teets", "teez", "testical", "testicle", "tit", "titfuck", "tits", "titt", "tittie5", "tittiefucker", "titties", "tittyfuck", "tittywank", "titwank", "tosser", "turd", "tw4t", "twat", "twathead", "twatty", "twunt", "twunter", "v14gra", "v1gra", "vagina", "viagra", "vulva", "w00se", "wang", "wank", "wanker", "wanky", "whoar", "whore", "willies", "willy", "xrated", "xxx"]
examples/example_1_1.jpg ADDED
examples/example_1_2.jpg ADDED
examples/example_2_1.jpg ADDED
examples/example_2_2.jpg ADDED
examples/example_3_1.jpg ADDED
examples/example_3_2.jpg ADDED
examples/example_4_1.jpg ADDED
examples/example_4_2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ torch==1.12.1
3
+ torchvision==0.13.1
4
+ git+https://github.com/openai/CLIP.git
5
+ git+https://github.com/layer6ai-labs/tr0n.git