frankleeeee commited on
Commit
8ef4ac6
1 Parent(s): 8d0ae18

Upload STDiT

Browse files
Files changed (1) hide show
  1. modeling_stdit.py +2 -1
modeling_stdit.py CHANGED
@@ -148,7 +148,8 @@ class STDiT(PreTrainedModel):
148
  tpe = self.pos_embed_temporal
149
  else:
150
  tpe = None
151
- x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
 
152
 
153
  if self.enable_sequence_parallelism:
154
  x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
 
148
  tpe = self.pos_embed_temporal
149
  else:
150
  tpe = None
151
+ x = block(x, y, t0, y_lens, tpe)
152
+ # x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
153
 
154
  if self.enable_sequence_parallelism:
155
  x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")