mazpie commited on
Commit
f83208c
1 Parent(s): ae407cb

Update demo/t2v.py

Browse files
Files changed (1) hide show
  1. demo/t2v.py +8 -2
demo/t2v.py CHANGED
@@ -36,14 +36,14 @@ class Text2Video():
36
  self.download_model(model_folder, model_filename)
37
  if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
38
  self.download_internvideo2(model_folder)
39
- self.agent = torch.load(os.path.join(model_folder, model_filename))
40
  model_name = 'internvideo2'
41
 
42
  # Get ViCLIP
43
  viclip_global_instance = ViCLIPGlobalInstance(model_name)
44
  if not viclip_global_instance._instantiated:
45
  print("Instantiating InternVideo2")
46
- viclip_global_instance.instantiate()
47
  self.clip = viclip_global_instance.viclip
48
  self.tokenizer = viclip_global_instance.viclip_tokenizer
49
 
@@ -51,8 +51,11 @@ class Text2Video():
51
  if not os.path.exists(self.result_dir):
52
  os.mkdir(self.result_dir)
53
 
 
54
  def get_prompt(self, prompt, duration):
55
  torch.cuda.empty_cache()
 
 
56
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
57
  start = time.time()
58
 
@@ -88,6 +91,9 @@ class Text2Video():
88
 
89
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
90
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
 
 
 
91
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
92
 
93
  def download_model(self, model_folder, model_filename):
 
36
  self.download_model(model_folder, model_filename)
37
  if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
38
  self.download_internvideo2(model_folder)
39
+ self.agent = torch.load(os.path.join(model_folder, model_filename),map_location='cpu')
40
  model_name = 'internvideo2'
41
 
42
  # Get ViCLIP
43
  viclip_global_instance = ViCLIPGlobalInstance(model_name)
44
  if not viclip_global_instance._instantiated:
45
  print("Instantiating InternVideo2")
46
+ viclip_global_instance.instantiate(device='cpu')
47
  self.clip = viclip_global_instance.viclip
48
  self.tokenizer = viclip_global_instance.viclip_tokenizer
49
 
 
51
  if not os.path.exists(self.result_dir):
52
  os.mkdir(self.result_dir)
53
 
54
+ @spaces.GPU
55
  def get_prompt(self, prompt, duration):
56
  torch.cuda.empty_cache()
57
+ self.agent.to('cuda')
58
+ self.clip.to('cuda')
59
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
60
  start = time.time()
61
 
 
91
 
92
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
93
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
94
+ # Offload GPU
95
+ self.agent.to('cpu')
96
+ self.clip.to('cpu')
97
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
98
 
99
  def download_model(self, model_folder, model_filename):