Spaces:
Runtime error
Runtime error
More refactoring
Browse files
app.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
from finetuning import FineTunedModel
|
6 |
from StableDiffuser import StableDiffuser
|
7 |
from tqdm import tqdm
|
8 |
-
|
9 |
|
10 |
model_map = {
|
11 |
'Car' : 'models/car.pt',
|
@@ -18,41 +16,16 @@ class Demo:
|
|
18 |
def __init__(self) -> None:
|
19 |
|
20 |
self.training = False
|
21 |
-
self.generating = False
|
22 |
-
self.nsteps = 50
|
23 |
|
24 |
-
self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda')
|
25 |
-
self.finetuner = None
|
26 |
-
|
27 |
|
28 |
with gr.Blocks() as demo:
|
29 |
self.layout()
|
30 |
-
self.switch_model(self.model_dropdown.value)
|
31 |
-
|
32 |
-
self.finetuner = self.finetuner.eval().half()
|
33 |
-
self.diffuser = self.diffuser.eval().half()
|
34 |
-
|
35 |
demo.queue(concurrency_count=2).launch()
|
36 |
|
37 |
-
def disable(self):
|
38 |
-
return [gr.update(interactive=False), gr.update(interactive=False)]
|
39 |
-
|
40 |
-
def switch_model(self, model_name):
|
41 |
-
|
42 |
-
if not model_name:
|
43 |
-
return
|
44 |
-
|
45 |
-
model_path = model_map[model_name]
|
46 |
-
|
47 |
-
checkpoint = torch.load(model_path)
|
48 |
-
|
49 |
-
self.finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint)
|
50 |
-
|
51 |
-
torch.cuda.empty_cache()
|
52 |
|
53 |
def layout(self):
|
54 |
|
55 |
-
|
56 |
with gr.Row():
|
57 |
|
58 |
|
@@ -149,25 +122,24 @@ class Demo:
|
|
149 |
|
150 |
with gr.Column(scale=1):
|
151 |
|
|
|
|
|
152 |
self.train_button = gr.Button(
|
153 |
value="Train",
|
154 |
)
|
155 |
|
156 |
self.download = gr.Files()
|
157 |
|
158 |
-
self.model_dropdown.change(self.switch_model, inputs=[self.model_dropdown])
|
159 |
self.infr_button.click(self.inference, inputs = [
|
160 |
self.prompt_input_infr,
|
161 |
-
self.seed_infr
|
|
|
162 |
],
|
163 |
outputs=[
|
164 |
self.image_new,
|
165 |
self.image_orig
|
166 |
]
|
167 |
)
|
168 |
-
self.train_button.click(self.disable,
|
169 |
-
outputs=[self.train_button, self.infr_button]
|
170 |
-
)
|
171 |
self.train_button.click(self.train, inputs = [
|
172 |
self.prompt_input,
|
173 |
self.train_method_input,
|
@@ -175,21 +147,13 @@ class Demo:
|
|
175 |
self.iterations_input,
|
176 |
self.lr_input
|
177 |
],
|
178 |
-
outputs=[self.train_button,
|
179 |
)
|
180 |
|
181 |
def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
182 |
|
183 |
if self.training:
|
184 |
-
return [
|
185 |
-
else:
|
186 |
-
self.training = True
|
187 |
-
|
188 |
-
del self.finetuner
|
189 |
-
|
190 |
-
torch.cuda.empty_cache()
|
191 |
-
|
192 |
-
self.diffuser = self.diffuser.train().float()
|
193 |
|
194 |
if train_method == 'ESD-x':
|
195 |
|
@@ -206,82 +170,35 @@ class Demo:
|
|
206 |
modules = ".*attn1$"
|
207 |
frozen = []
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
|
212 |
-
criteria = torch.nn.MSELoss()
|
213 |
-
|
214 |
-
pbar = tqdm(range(iterations))
|
215 |
|
216 |
-
|
217 |
|
218 |
-
|
219 |
-
positive_text_embeddings = self.diffuser.get_text_embeddings([prompt],n_imgs=1)
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
with torch.no_grad():
|
224 |
-
|
225 |
-
self.diffuser.set_scheduler_timesteps(self.nsteps)
|
226 |
-
|
227 |
-
optimizer.zero_grad()
|
228 |
-
|
229 |
-
iteration = torch.randint(1, self.nsteps - 1, (1,)).item()
|
230 |
-
|
231 |
-
latents = self.diffuser.get_initial_latents(1, 512, 1)
|
232 |
-
|
233 |
-
with finetuner:
|
234 |
|
235 |
-
|
236 |
-
latents,
|
237 |
-
positive_text_embeddings,
|
238 |
-
start_iteration=0,
|
239 |
-
end_iteration=iteration,
|
240 |
-
guidance_scale=3,
|
241 |
-
show_progress=False
|
242 |
-
)
|
243 |
|
244 |
-
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
positive_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
|
249 |
-
neutral_latents = self.diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
|
250 |
|
251 |
-
|
252 |
-
negative_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
|
253 |
|
254 |
-
positive_latents.requires_grad = False
|
255 |
-
neutral_latents.requires_grad = False
|
256 |
|
257 |
-
|
258 |
-
loss.backward()
|
259 |
-
optimizer.step()
|
260 |
|
261 |
-
|
262 |
-
torch.save(finetuner.state_dict(), ft_path)
|
263 |
|
264 |
-
|
|
|
|
|
265 |
|
266 |
-
self.
|
267 |
|
268 |
torch.cuda.empty_cache()
|
269 |
|
270 |
-
self.training = False
|
271 |
-
|
272 |
-
model_map['Custom'] = ft_path
|
273 |
-
|
274 |
-
return [gr.update(interactive=True), gr.update(interactive=True), ft_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
275 |
-
|
276 |
-
|
277 |
-
def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
|
278 |
-
if self.generating:
|
279 |
-
return [None, None]
|
280 |
-
else:
|
281 |
-
self.generating = True
|
282 |
-
|
283 |
-
self.diffuser._seed = seed or 42
|
284 |
-
|
285 |
images = self.diffuser(
|
286 |
prompt,
|
287 |
n_steps=50,
|
@@ -302,8 +219,6 @@ class Demo:
|
|
302 |
|
303 |
edited_image = images[0][0]
|
304 |
|
305 |
-
self.generating = False
|
306 |
-
|
307 |
torch.cuda.empty_cache()
|
308 |
|
309 |
return edited_image, orig_image
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from finetuning import FineTunedModel
|
4 |
from StableDiffuser import StableDiffuser
|
5 |
from tqdm import tqdm
|
6 |
+
from train import train
|
7 |
|
8 |
model_map = {
|
9 |
'Car' : 'models/car.pt',
|
|
|
16 |
def __init__(self) -> None:
|
17 |
|
18 |
self.training = False
|
|
|
|
|
19 |
|
20 |
+
self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda').eval().half()
|
|
|
|
|
21 |
|
22 |
with gr.Blocks() as demo:
|
23 |
self.layout()
|
|
|
|
|
|
|
|
|
|
|
24 |
demo.queue(concurrency_count=2).launch()
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def layout(self):
|
28 |
|
|
|
29 |
with gr.Row():
|
30 |
|
31 |
|
|
|
122 |
|
123 |
with gr.Column(scale=1):
|
124 |
|
125 |
+
self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
|
126 |
+
|
127 |
self.train_button = gr.Button(
|
128 |
value="Train",
|
129 |
)
|
130 |
|
131 |
self.download = gr.Files()
|
132 |
|
|
|
133 |
self.infr_button.click(self.inference, inputs = [
|
134 |
self.prompt_input_infr,
|
135 |
+
self.seed_infr,
|
136 |
+
self.model_dropdown
|
137 |
],
|
138 |
outputs=[
|
139 |
self.image_new,
|
140 |
self.image_orig
|
141 |
]
|
142 |
)
|
|
|
|
|
|
|
143 |
self.train_button.click(self.train, inputs = [
|
144 |
self.prompt_input,
|
145 |
self.train_method_input,
|
|
|
147 |
self.iterations_input,
|
148 |
self.lr_input
|
149 |
],
|
150 |
+
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
151 |
)
|
152 |
|
153 |
def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
154 |
|
155 |
if self.training:
|
156 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
if train_method == 'ESD-x':
|
159 |
|
|
|
170 |
modules = ".*attn1$"
|
171 |
frozen = []
|
172 |
|
173 |
+
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
|
176 |
|
177 |
+
self.training = True
|
|
|
178 |
|
179 |
+
train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
+
self.training = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
+
torch.cuda.empty_cache()
|
184 |
|
185 |
+
model_map['Custom'] = save_path
|
|
|
|
|
|
|
186 |
|
187 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
|
|
188 |
|
|
|
|
|
189 |
|
190 |
+
def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
|
|
|
|
|
191 |
|
192 |
+
self.diffuser._seed = seed or 42
|
|
|
193 |
|
194 |
+
model_path = model_map[model_name]
|
195 |
+
|
196 |
+
checkpoint = torch.load(model_path)
|
197 |
|
198 |
+
self.finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
199 |
|
200 |
torch.cuda.empty_cache()
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
images = self.diffuser(
|
203 |
prompt,
|
204 |
n_steps=50,
|
|
|
219 |
|
220 |
edited_image = images[0][0]
|
221 |
|
|
|
|
|
222 |
torch.cuda.empty_cache()
|
223 |
|
224 |
return edited_image, orig_image
|
train.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from StableDiffuser import StableDiffuser
|
2 |
+
from finetuning import FineTunedModel
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
    """Fine-tune part of a Stable Diffusion model against ``prompt`` and save it.

    A fresh ``StableDiffuser`` is created on CUDA and wrapped in a
    ``FineTunedModel``. Each step regresses the fine-tuned noise prediction
    for ``prompt`` towards
    ``neutral - negative_guidance * (positive - neutral)``, where ``positive``
    and ``neutral`` are gradient-free predictions made outside the
    ``finetuner`` context for ``prompt`` and the empty prompt respectively.

    Args:
        prompt: Text prompt used for the positive text embeddings.
        modules: Forwarded to ``FineTunedModel`` as its module selector
            (presumably a name pattern such as ``".*attn1$"`` -- TODO confirm
            against ``finetuning.FineTunedModel``).
        freeze_modules: Forwarded to ``FineTunedModel`` as ``frozen_modules``.
        iterations: Number of optimisation steps.
        negative_guidance: Scale applied to ``positive - neutral`` in the
            regression target.
        lr: Learning rate for the Adam optimiser.
        save_path: Destination passed to ``torch.save`` for the finetuner's
            ``state_dict()``.
    """
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    # The text embeddings are fixed for the whole run; no gradients needed.
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

    for _ in pbar:

        # Everything in this context only builds inputs/targets for the step.
        with torch.no_grad():

            diffuser.set_scheduler_timesteps(nsteps)

            optimizer.zero_grad()

            # torch.randint's upper bound is exclusive: 1 <= iteration <= nsteps - 2.
            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            latents = diffuser.get_initial_latents(1, 512, 1)

            # Partially denoise with the finetuner context active.
            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )

            # Switch to a 1000-step schedule and rescale the sampled step
            # index from the nsteps-step schedule accordingly.
            diffuser.set_scheduler_timesteps(1000)
            iteration = int(iteration / nsteps * 1000)

            # Reference predictions made outside the finetuner context.
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        # The only prediction made with gradients enabled.
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

        # Original author's note: loss = criteria(e_n, e_0) works the best, try 5000 epochs.
        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
        loss.backward()
        optimizer.step()

    torch.save(finetuner.state_dict(), save_path)
|
69 |
+
|
70 |
+
if __name__ == '__main__':

    # Command-line entry point: every option maps one-to-one onto a keyword
    # argument of train().
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    # Zero or more values: train() is also called with an empty frozen list,
    # which nargs='+' with required=True could not express. Omitting the flag
    # or passing it with no values both yield [].
    parser.add_argument('--freeze_modules', nargs='*', default=[])
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))
|