JadenFK commited on
Commit
81ccbca
1 Parent(s): aaad790

More refactoring

Browse files
Files changed (2) hide show
  1. app.py +22 -107
  2. train.py +84 -0
app.py CHANGED
@@ -1,11 +1,9 @@
1
- from pathlib import Path
2
-
3
  import gradio as gr
4
  import torch
5
  from finetuning import FineTunedModel
6
  from StableDiffuser import StableDiffuser
7
  from tqdm import tqdm
8
-
9
 
10
  model_map = {
11
  'Car' : 'models/car.pt',
@@ -18,41 +16,16 @@ class Demo:
18
  def __init__(self) -> None:
19
 
20
  self.training = False
21
- self.generating = False
22
- self.nsteps = 50
23
 
24
- self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda')
25
- self.finetuner = None
26
-
27
 
28
  with gr.Blocks() as demo:
29
  self.layout()
30
- self.switch_model(self.model_dropdown.value)
31
-
32
- self.finetuner = self.finetuner.eval().half()
33
- self.diffuser = self.diffuser.eval().half()
34
-
35
  demo.queue(concurrency_count=2).launch()
36
 
37
- def disable(self):
38
- return [gr.update(interactive=False), gr.update(interactive=False)]
39
-
40
- def switch_model(self, model_name):
41
-
42
- if not model_name:
43
- return
44
-
45
- model_path = model_map[model_name]
46
-
47
- checkpoint = torch.load(model_path)
48
-
49
- self.finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint)
50
-
51
- torch.cuda.empty_cache()
52
 
53
  def layout(self):
54
 
55
-
56
  with gr.Row():
57
 
58
 
@@ -149,25 +122,24 @@ class Demo:
149
 
150
  with gr.Column(scale=1):
151
 
 
 
152
  self.train_button = gr.Button(
153
  value="Train",
154
  )
155
 
156
  self.download = gr.Files()
157
 
158
- self.model_dropdown.change(self.switch_model, inputs=[self.model_dropdown])
159
  self.infr_button.click(self.inference, inputs = [
160
  self.prompt_input_infr,
161
- self.seed_infr
 
162
  ],
163
  outputs=[
164
  self.image_new,
165
  self.image_orig
166
  ]
167
  )
168
- self.train_button.click(self.disable,
169
- outputs=[self.train_button, self.infr_button]
170
- )
171
  self.train_button.click(self.train, inputs = [
172
  self.prompt_input,
173
  self.train_method_input,
@@ -175,21 +147,13 @@ class Demo:
175
  self.iterations_input,
176
  self.lr_input
177
  ],
178
- outputs=[self.train_button, self.infr_button, self.download, self.model_dropdown]
179
  )
180
 
181
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
182
 
183
  if self.training:
184
- return [None, None, None]
185
- else:
186
- self.training = True
187
-
188
- del self.finetuner
189
-
190
- torch.cuda.empty_cache()
191
-
192
- self.diffuser = self.diffuser.train().float()
193
 
194
  if train_method == 'ESD-x':
195
 
@@ -206,82 +170,35 @@ class Demo:
206
  modules = ".*attn1$"
207
  frozen = []
208
 
209
- finetuner = FineTunedModel(self.diffuser, modules, frozen_modules=frozen)
210
-
211
- optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
212
- criteria = torch.nn.MSELoss()
213
-
214
- pbar = tqdm(range(iterations))
215
 
216
- with torch.no_grad():
217
 
218
- neutral_text_embeddings = self.diffuser.get_text_embeddings([''],n_imgs=1)
219
- positive_text_embeddings = self.diffuser.get_text_embeddings([prompt],n_imgs=1)
220
 
221
- for i in pbar:
222
-
223
- with torch.no_grad():
224
-
225
- self.diffuser.set_scheduler_timesteps(self.nsteps)
226
-
227
- optimizer.zero_grad()
228
-
229
- iteration = torch.randint(1, self.nsteps - 1, (1,)).item()
230
-
231
- latents = self.diffuser.get_initial_latents(1, 512, 1)
232
-
233
- with finetuner:
234
 
235
- latents_steps, _ = self.diffuser.diffusion(
236
- latents,
237
- positive_text_embeddings,
238
- start_iteration=0,
239
- end_iteration=iteration,
240
- guidance_scale=3,
241
- show_progress=False
242
- )
243
 
244
- self.diffuser.set_scheduler_timesteps(1000)
245
 
246
- iteration = int(iteration / self.nsteps * 1000)
247
-
248
- positive_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
249
- neutral_latents = self.diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
250
 
251
- with finetuner:
252
- negative_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
253
 
254
- positive_latents.requires_grad = False
255
- neutral_latents.requires_grad = False
256
 
257
- loss = criteria(negative_latents, neutral_latents - (neg_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
258
- loss.backward()
259
- optimizer.step()
260
 
261
- ft_path = f"{prompt.lower().replace(' ', '')}.pt"
262
- torch.save(finetuner.state_dict(), ft_path)
263
 
264
- self.finetuner = finetuner.eval().half()
 
 
265
 
266
- self.diffuser = self.diffuser.eval().half()
267
 
268
  torch.cuda.empty_cache()
269
 
270
- self.training = False
271
-
272
- model_map['Custom'] = ft_path
273
-
274
- return [gr.update(interactive=True), gr.update(interactive=True), ft_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
275
-
276
-
277
- def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
278
- if self.generating:
279
- return [None, None]
280
- else:
281
- self.generating = True
282
-
283
- self.diffuser._seed = seed or 42
284
-
285
  images = self.diffuser(
286
  prompt,
287
  n_steps=50,
@@ -302,8 +219,6 @@ class Demo:
302
 
303
  edited_image = images[0][0]
304
 
305
- self.generating = False
306
-
307
  torch.cuda.empty_cache()
308
 
309
  return edited_image, orig_image
 
 
 
1
  import gradio as gr
2
  import torch
3
  from finetuning import FineTunedModel
4
  from StableDiffuser import StableDiffuser
5
  from tqdm import tqdm
6
+ from train import train
7
 
8
  model_map = {
9
  'Car' : 'models/car.pt',
 
16
  def __init__(self) -> None:
17
 
18
  self.training = False
 
 
19
 
20
+ self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
 
 
21
 
22
  with gr.Blocks() as demo:
23
  self.layout()
 
 
 
 
 
24
  demo.queue(concurrency_count=2).launch()
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def layout(self):
28
 
 
29
  with gr.Row():
30
 
31
 
 
122
 
123
  with gr.Column(scale=1):
124
 
125
+ self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
126
+
127
  self.train_button = gr.Button(
128
  value="Train",
129
  )
130
 
131
  self.download = gr.Files()
132
 
 
133
  self.infr_button.click(self.inference, inputs = [
134
  self.prompt_input_infr,
135
+ self.seed_infr,
136
+ self.model_dropdown
137
  ],
138
  outputs=[
139
  self.image_new,
140
  self.image_orig
141
  ]
142
  )
 
 
 
143
  self.train_button.click(self.train, inputs = [
144
  self.prompt_input,
145
  self.train_method_input,
 
147
  self.iterations_input,
148
  self.lr_input
149
  ],
150
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
151
  )
152
 
153
  def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
154
 
155
  if self.training:
156
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
 
 
 
 
 
 
 
157
 
158
  if train_method == 'ESD-x':
159
 
 
170
  modules = ".*attn1$"
171
  frozen = []
172
 
173
+ randn = torch.randint(1, 10000000, (1,)).item()
 
 
 
 
 
174
 
175
+ save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
176
 
177
+ self.training = True
 
178
 
179
+ train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ self.training = False
 
 
 
 
 
 
 
182
 
183
+ torch.cuda.empty_cache()
184
 
185
+ model_map['Custom'] = save_path
 
 
 
186
 
187
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
 
188
 
 
 
189
 
190
+ def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
 
 
191
 
192
+ self.diffuser._seed = seed or 42
 
193
 
194
+ model_path = model_map[model_name]
195
+
196
+ checkpoint = torch.load(model_path)
197
 
198
+ self.finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
199
 
200
  torch.cuda.empty_cache()
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  images = self.diffuser(
203
  prompt,
204
  n_steps=50,
 
219
 
220
  edited_image = images[0][0]
221
 
 
 
222
  torch.cuda.empty_cache()
223
 
224
  return edited_image, orig_image
train.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from StableDiffuser import StableDiffuser
2
+ from finetuning import FineTunedModel
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
7
+
8
+ nsteps = 50
9
+
10
+ diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
+ diffuser.train()
12
+
13
+ finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
14
+
15
+ optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
16
+ criteria = torch.nn.MSELoss()
17
+
18
+ pbar = tqdm(range(iterations))
19
+
20
+ with torch.no_grad():
21
+
22
+ neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
23
+ positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
24
+
25
+ losses = []
26
+
27
+ for i in pbar:
28
+
29
+ with torch.no_grad():
30
+
31
+ diffuser.set_scheduler_timesteps(nsteps)
32
+
33
+ optimizer.zero_grad()
34
+
35
+ iteration = torch.randint(1, nsteps - 1, (1,)).item()
36
+
37
+ latents = diffuser.get_initial_latents(1, 512, 1)
38
+
39
+ with finetuner:
40
+
41
+ latents_steps, _ = diffuser.diffusion(
42
+ latents,
43
+ positive_text_embeddings,
44
+ start_iteration=0,
45
+ end_iteration=iteration,
46
+ guidance_scale=3,
47
+ show_progress=False
48
+ )
49
+
50
+ diffuser.set_scheduler_timesteps(1000)
51
+
52
+ iteration = int(iteration / nsteps * 1000)
53
+
54
+ positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
55
+ neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
56
+
57
+ with finetuner:
58
+ negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
59
+
60
+ positive_latents.requires_grad = False
61
+ neutral_latents.requires_grad = False
62
+
63
+ loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
64
+ loss.backward()
65
+ losses.append(loss.item())
66
+ optimizer.step()
67
+
68
+ torch.save(finetuner.state_dict(), save_path)
69
+
70
+ if __name__ == '__main__':
71
+
72
+ import argparse
73
+
74
+ parser = argparse.ArgumentParser()
75
+
76
+ parser.add_argument('--prompt', required=True)
77
+ parser.add_argument('--modules', required=True)
78
+ parser.add_argument('--freeze_modules', nargs='+', required=True)
79
+ parser.add_argument('--save_path', required=True)
80
+ parser.add_argument('--iterations', type=int, required=True)
81
+ parser.add_argument('--lr', type=float, required=True)
82
+ parser.add_argument('--negative_guidance', type=float, required=True)
83
+
84
+ train(**vars(parser.parse_args()))