JadenFK commited on
Commit
7021212
1 Parent(s): 46ccc0e

Memory saving

Browse files
Files changed (1) hide show
  1. train.py +12 -2
train.py CHANGED
@@ -10,6 +10,9 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
10
  diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
  diffuser.train()
12
 
 
 
 
13
  finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
14
 
15
  optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
@@ -22,7 +25,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
22
  neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
23
  positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
24
 
25
- losses = []
 
 
 
 
26
 
27
  for i in pbar:
28
 
@@ -61,8 +68,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
61
  neutral_latents.requires_grad = False
62
 
63
  loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
 
 
 
 
64
  loss.backward()
65
- losses.append(loss.item())
66
  optimizer.step()
67
 
68
  torch.save(finetuner.state_dict(), save_path)
 
10
  diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
11
  diffuser.train()
12
 
13
+
14
+
15
+
16
  finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
17
 
18
  optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
 
25
  neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
26
  positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
27
 
28
+ del diffuser.vae
29
+ del diffuser.text_encoder
30
+ del diffuser.tokenizer
31
+
32
+ torch.cuda.empty_cache()
33
 
34
  for i in pbar:
35
 
 
68
  neutral_latents.requires_grad = False
69
 
70
  loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
71
+
72
+ del negative_latents, neutral_latents, positive_latents, latents_steps, latents
73
+ torch.cuda.empty_cache()
74
+
75
  loss.backward()
 
76
  optimizer.step()
77
 
78
  torch.save(finetuner.state_dict(), save_path)