bluestarburst committed
Commit: 9a6a590
Parent: 2057037

Upload folder using huggingface_hub

Files changed (2):
  1. animatediff/models/motion_module.py +6 -1
  2. train.py +12 -0
animatediff/models/motion_module.py CHANGED
@@ -308,9 +308,14 @@ class VersatileAttention(CrossAttention):
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                 attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
+        if not hasattr(self, '_use_memory_efficient_attention_xformers'):
+            self._use_memory_efficient_attention_xformers = True
+
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+            self.set_use_memory_efficient_attention_xformers(True)
+            # hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
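The hunk above guards against motion modules whose _use_memory_efficient_attention_xformers flag was never initialized and then routes attention through xformers. A minimal sketch of the same defensive toggle applied from the outside, assuming a diffusers-style model whose attention modules expose set_use_memory_efficient_attention_xformers (the helper name try_enable_xformers is illustrative, not part of this repository):

# Sketch (illustrative): turn on xformers attention for every module that
# supports it, and skip quietly when xformers is not installed.
import importlib.util

def try_enable_xformers(model):
    if importlib.util.find_spec("xformers") is None:
        print("xformers is not installed; keeping the default attention path")
        return
    for name, module in model.named_modules():
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(True)

Running a check like this once on the UNet before training would avoid relying on the per-forward hasattr guard that the hunk adds.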
train.py CHANGED
@@ -177,6 +177,7 @@ def main(
     for name, module in unet.named_modules():
         if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
             for params in module.parameters():
+                print("trainable", name)
                 params.requires_grad = True
 
     if enable_xformers_memory_efficient_attention:
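The added print in this first hunk is a quick way to confirm which motion-module parameters are actually unfrozen. A small sketch of the same check in a reusable form, assuming the unet, train_whole_module, and trainable_modules names from the hunk (the function name summarize_trainable is hypothetical):

# Sketch (illustrative): unfreeze the selected motion modules and report
# how many parameters will receive gradients.
def summarize_trainable(unet, trainable_modules, train_whole_module=False):
    for name, module in unet.named_modules():
        if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
            print("trainable", name)
            for param in module.parameters():
                param.requires_grad = True
    n_trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in unet.parameters())
    print(f"trainable parameters: {n_trainable:,} / {n_total:,}")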
@@ -370,10 +371,21 @@ def main(
                 avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                 train_loss += avg_loss.item() / gradient_accumulation_steps
 
+                for name, module in unet.named_modules():
+                    if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
+                        for params in module.parameters():
+                            params.requires_grad = True
+
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
+
+                # for param in unet.parameters():
+                #     print(param.grad)
+
+
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
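For reference, this second hunk sits inside the usual Accelerate training step: gather the loss for logging, backpropagate, clip gradients only when they are synchronized across processes, then step the optimizer and scheduler. A minimal sketch of that step as a standalone function, assuming objects already wrapped by accelerator.prepare(...) (the function name training_step is illustrative, not the repository's code):

# Sketch (illustrative): one training step with Accelerate, mirroring the
# backward / clip / step / zero_grad sequence in the hunk above.
def training_step(accelerator, unet, optimizer, lr_scheduler,
                  loss, train_batch_size, gradient_accumulation_steps, max_grad_norm):
    # Average the loss across processes for logging.
    avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
    logged_loss = avg_loss.item() / gradient_accumulation_steps

    accelerator.backward(loss)
    if accelerator.sync_gradients:
        # Clip only on steps where gradients are actually synchronized.
        accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    return logged_loss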