anonymous committed on
Commit
2896183
1 Parent(s): 0a4007d
Files changed (2) hide show
  1. app.py +24 -16
  2. src/ddim_v_hacked.py +5 -3
app.py CHANGED
@@ -303,6 +303,8 @@ def process1(*args):
303
  imgs = sorted(os.listdir(cfg.input_dir))
304
  imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
305
 
 
 
306
  with torch.no_grad():
307
  frame = cv2.imread(imgs[0])
308
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -607,6 +609,7 @@ def process2(*args):
607
 
608
  return key_video_path
609
 
 
610
  DESCRIPTION = '''
611
  ## Rerender A Video
612
  ### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
@@ -644,12 +647,13 @@ with block:
644
  run_button3 = gr.Button(value='Run Propagation')
645
  with gr.Accordion('Advanced options for the 1st frame translation',
646
  open=False):
647
- image_resolution = gr.Slider(label='Frame rsolution',
648
- minimum=256,
649
- maximum=512,
650
- value=512,
651
- step=64,
652
- info='To avoid overload, maximum 512')
 
653
  control_strength = gr.Slider(label='ControNet strength',
654
  minimum=0.0,
655
  maximum=2.0,
@@ -734,12 +738,13 @@ with block:
734
  value=1,
735
  step=1,
736
  info='Uniformly sample the key frames every K frames')
737
- keyframe_count = gr.Slider(label='Number of key frames',
738
- minimum=1,
739
- maximum=1,
740
- value=1,
741
- step=1,
742
- info='To avoid overload, maximum 8 key frames')
 
743
 
744
  use_constraints = gr.CheckboxGroup(
745
  [
@@ -769,8 +774,10 @@ with block:
769
  maximum=100,
770
  value=1,
771
  step=1,
772
- info=('Update the key and value for '
773
- 'cross-frame attention every N key frames (recommend N*K>=10)'))
 
 
774
  with gr.Row():
775
  warp_start = gr.Slider(label='Shape-aware fusion start',
776
  minimum=0,
@@ -912,8 +919,9 @@ with block:
912
  run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
913
 
914
  def process3():
915
- raise gr.Error("Coming Soon. Full code for full video translation will be "
916
- "released upon the publication of the paper.")
 
917
 
918
  run_button3.click(fn=process3, outputs=[result_keyframe])
919
 
 
303
  imgs = sorted(os.listdir(cfg.input_dir))
304
  imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
305
 
306
+ model.cond_stage_model.device = device
307
+
308
  with torch.no_grad():
309
  frame = cv2.imread(imgs[0])
310
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
609
 
610
  return key_video_path
611
 
612
+
613
  DESCRIPTION = '''
614
  ## Rerender A Video
615
  ### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
 
647
  run_button3 = gr.Button(value='Run Propagation')
648
  with gr.Accordion('Advanced options for the 1st frame translation',
649
  open=False):
650
+ image_resolution = gr.Slider(
651
+ label='Frame rsolution',
652
+ minimum=256,
653
+ maximum=512,
654
+ value=512,
655
+ step=64,
656
+ info='To avoid overload, maximum 512')
657
  control_strength = gr.Slider(label='ControNet strength',
658
  minimum=0.0,
659
  maximum=2.0,
 
738
  value=1,
739
  step=1,
740
  info='Uniformly sample the key frames every K frames')
741
+ keyframe_count = gr.Slider(
742
+ label='Number of key frames',
743
+ minimum=1,
744
+ maximum=1,
745
+ value=1,
746
+ step=1,
747
+ info='To avoid overload, maximum 8 key frames')
748
 
749
  use_constraints = gr.CheckboxGroup(
750
  [
 
774
  maximum=100,
775
  value=1,
776
  step=1,
777
+ info=
778
+ ('Update the key and value for '
779
+ 'cross-frame attention every N key frames (recommend N*K>=10)'
780
+ ))
781
  with gr.Row():
782
  warp_start = gr.Slider(label='Shape-aware fusion start',
783
  minimum=0,
 
919
  run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
920
 
921
  def process3():
922
+ raise gr.Error(
923
+ "Coming Soon. Full code for full video translation will be "
924
+ "released upon the publication of the paper.")
925
 
926
  run_button3.click(fn=process3, outputs=[result_keyframe])
927
 
src/ddim_v_hacked.py CHANGED
@@ -14,6 +14,8 @@ from ControlNet.ldm.modules.diffusionmodules.util import (
14
 
15
  _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
16
 
 
 
17
 
18
  def register_attention_control(model, controller=None):
19
 
@@ -36,7 +38,7 @@ def register_attention_control(model, controller=None):
36
 
37
  # force cast to fp32 to avoid overflowing
38
  if _ATTN_PRECISION == 'fp32':
39
- with torch.autocast(enabled=False, device_type='cuda'):
40
  q, k = q.float(), k.float()
41
  sim = torch.einsum('b i d, b j d -> b i j', q,
42
  k) * self.scale
@@ -98,8 +100,8 @@ class DDIMVSampler(object):
98
 
99
  def register_buffer(self, name, attr):
100
  if type(attr) == torch.Tensor:
101
- if attr.device != torch.device('cuda'):
102
- attr = attr.to(torch.device('cuda'))
103
  setattr(self, name, attr)
104
 
105
  def make_schedule(self,
 
14
 
15
  _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
16
 
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+
19
 
20
  def register_attention_control(model, controller=None):
21
 
 
38
 
39
  # force cast to fp32 to avoid overflowing
40
  if _ATTN_PRECISION == 'fp32':
41
+ with torch.autocast(enabled=False, device_type=device):
42
  q, k = q.float(), k.float()
43
  sim = torch.einsum('b i d, b j d -> b i j', q,
44
  k) * self.scale
 
100
 
101
  def register_buffer(self, name, attr):
102
  if type(attr) == torch.Tensor:
103
+ if attr.device != torch.device(device):
104
+ attr = attr.to(torch.device(device))
105
  setattr(self, name, attr)
106
 
107
  def make_schedule(self,