RohitGandikota committed
Commit 991663d · 1 Parent(s): 69ec7a4

disabling training for GPU<40GB

app.py CHANGED
@@ -227,12 +227,6 @@ class Demo:
         )
 
     def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
-        # if target_concept is None:
-        #     target_concept = ''
-        # if positive_prompt is None:
-        #     positive_prompt = ''
-        # if negative_prompt is None:
-        #     negative_prompt = ''
 
         if attributes_input == '':
             attributes_input = None
@@ -244,6 +238,8 @@ class Demo:
         save_name += f'_noxattn'
         save_name += f'_rank_{rank}.pt'
 
+        if torch.cuda.get_device_properties(0).total_memory * 1e-9 < 40:
+            return [gr.update(interactive=True, value='Train'), gr.update(value='GPU memory is not enough for training... Please upgrade to a GPU with at least 40GB or clone the repo to your local machine.'), None, gr.update()]
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
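For reference, a minimal standalone sketch of the VRAM gate this commit adds. The 40 GB threshold, the `* 1e-9` bytes-to-gigabytes conversion, and device index 0 come from the diff above; the `has_enough_vram` helper name and the rest are illustrative, not part of the repo:

    import torch

    def has_enough_vram(min_gb: float = 40.0, device: int = 0) -> bool:
        """True if CUDA device `device` reports at least `min_gb` GB of total memory."""
        if not torch.cuda.is_available():
            return False
        # total_memory is in bytes; multiplying by 1e-9 gives decimal GB,
        # matching the `* 1e-9 < 40` check in train() above.
        return torch.cuda.get_device_properties(device).total_memory * 1e-9 >= min_gb

In practice this gates training in the hosted demo to A100-class (40 GB+) GPUs; anyone on smaller hardware is pointed at cloning the repo and training locally instead.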
trainscripts/textsliders/data/config-xl.yaml CHANGED
@@ -9,7 +9,7 @@ network:
   alpha: 1.0
   training_method: "xattn"
 train:
-  precision: "fp16"
+  precision: "bfloat16"
   noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
   iterations: 1000
   lr: 0.0002
@@ -20,9 +20,9 @@ save:
   name: "temp"
   path: "./models"
   per_steps: 5000000
-  precision: "fp32"
+  precision: "bfloat16"
 logging:
   use_wandb: false
   verbose: false
 other:
-  use_xformers: true
+  use_xformers: true
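
The other half of the commit moves both the training precision and the save precision from fp16/fp32 to bfloat16. As a rough sketch of how a precision string like this is typically mapped to a torch dtype when the config is loaded (the `parse_precision` name is hypothetical; the repo's own loader may differ):

    import torch

    # Hypothetical helper: maps a precision string from config-xl.yaml
    # to a torch dtype. Illustrative only; not the repo's actual loader.
    def parse_precision(precision: str) -> torch.dtype:
        if precision in ("fp16", "float16"):
            return torch.float16
        if precision in ("bf16", "bfloat16"):
            return torch.bfloat16
        if precision in ("fp32", "float32"):
            return torch.float32
        raise ValueError(f"unsupported precision: {precision!r}")

    # e.g. weight_dtype = parse_precision(train_config["precision"])  # -> torch.bfloat16

bfloat16 keeps fp16's memory footprint but carries fp32's 8-bit exponent range, so it sidesteps the overflow/NaN issues fp16 training can hit; the trade-off is a shorter mantissa, which diffusion training generally tolerates.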