fffiloni committed
Commit
14eb1b8
1 Parent(s): 3801c88

Update models/pipeline_controlvideo.py

Files changed (1)
  1. models/pipeline_controlvideo.py +40 -2
models/pipeline_controlvideo.py CHANGED
@@ -652,7 +652,7 @@ class ControlVideoPipeline(DiffusionPipeline):
         prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
         alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
         return alpha_prod_t_prev
-
+    """
     def get_slide_window_indices(self, video_length, window_size):
         assert window_size >=3
         key_frame_indices = np.arange(0, video_length, window_size-1).tolist()
@@ -668,7 +668,45 @@ class ControlVideoPipeline(DiffusionPipeline):
                 continue
             inter_frame_list.append(s[1:].tolist())
         return key_frame_indices, inter_frame_list
-
+    """
+    def get_slide_window_indices(self, video_length, window_size):
+        assert window_size >= 3
+
+        # Define the chunk size for processing
+        chunk_size = 4
+
+        # Calculate the number of chunks
+        num_chunks = (video_length - 1) // chunk_size + 1
+
+        # Initialize the lists to store the results
+        key_frame_indices = []
+        inter_frame_list = []
+
+        for chunk_index in range(num_chunks):
+            # Calculate the start and end indices for the current chunk
+            start_index = chunk_index * chunk_size
+            end_index = min((chunk_index + 1) * chunk_size, video_length)
+
+            # Generate key frame indices for the current chunk
+            chunk_key_frame_indices = np.arange(start_index, end_index, window_size - 1).tolist()
+
+            # Append the last index if it's not already included
+            if chunk_key_frame_indices[-1] != (end_index - 1):
+                chunk_key_frame_indices.append(end_index - 1)
+
+            # Append the key frame indices of the current chunk to the overall list
+            key_frame_indices.extend(chunk_key_frame_indices)
+
+            # Generate slices for the current chunk
+            chunk_slices = np.split(np.arange(start_index, end_index), chunk_key_frame_indices)
+
+            # Process each slice in the current chunk
+            for s in chunk_slices:
+                if len(s) < 2:
+                    continue
+                inter_frame_list.append(s[1:].tolist())
+
+        return key_frame_indices, inter_frame_list
     @torch.no_grad()
     def __call__(
         self,
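
For a quick sanity check of the change, below is a minimal standalone sketch that mirrors the body of the new chunked get_slide_window_indices: compared with the commented-out version, which computed key frames over the whole clip at once, the new code builds them per 4-frame chunk. Exposing chunk_size as a parameter and the sample video_length / window_size values are illustrative choices for this sketch, not part of the commit.

import numpy as np

# Standalone sketch of the chunked sliding-window indexing added above.
# chunk_size is a parameter here for experimentation; the commit hard-codes it to 4.
def slide_window_indices(video_length, window_size, chunk_size=4):
    assert window_size >= 3
    # Ceiling division: number of chunks needed to cover all frames
    num_chunks = (video_length - 1) // chunk_size + 1
    key_frame_indices, inter_frame_list = [], []
    for chunk_index in range(num_chunks):
        start_index = chunk_index * chunk_size
        end_index = min((chunk_index + 1) * chunk_size, video_length)
        # Key frames inside the chunk, spaced window_size - 1 apart,
        # always ending on the chunk's last frame
        chunk_keys = np.arange(start_index, end_index, window_size - 1).tolist()
        if chunk_keys[-1] != end_index - 1:
            chunk_keys.append(end_index - 1)
        key_frame_indices.extend(chunk_keys)
        # np.split uses the key-frame values as split positions within the
        # chunk's arange; every resulting group of length >= 2 contributes
        # its tail (all elements after the first) as inter frames
        for s in np.split(np.arange(start_index, end_index), chunk_keys):
            if len(s) < 2:
                continue
            inter_frame_list.append(s[1:].tolist())
    return key_frame_indices, inter_frame_list

if __name__ == "__main__":
    keys, inters = slide_window_indices(video_length=8, window_size=3)
    print("key frames:      ", keys)    # -> [0, 2, 3, 4, 6, 7]
    print("inter-frame sets:", inters)  # -> [[1], [5, 6, 7]]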