Spaces:
Sleeping
Sleeping
Update demo/t2v.py
Browse files- demo/t2v.py +8 -2
demo/t2v.py
CHANGED
@@ -36,14 +36,14 @@ class Text2Video():
|
|
36 |
self.download_model(model_folder, model_filename)
|
37 |
if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
|
38 |
self.download_internvideo2(model_folder)
|
39 |
-
self.agent = torch.load(os.path.join(model_folder, model_filename))
|
40 |
model_name = 'internvideo2'
|
41 |
|
42 |
# Get ViCLIP
|
43 |
viclip_global_instance = ViCLIPGlobalInstance(model_name)
|
44 |
if not viclip_global_instance._instantiated:
|
45 |
print("Instantiating InternVideo2")
|
46 |
-
viclip_global_instance.instantiate()
|
47 |
self.clip = viclip_global_instance.viclip
|
48 |
self.tokenizer = viclip_global_instance.viclip_tokenizer
|
49 |
|
@@ -51,8 +51,11 @@ class Text2Video():
|
|
51 |
if not os.path.exists(self.result_dir):
|
52 |
os.mkdir(self.result_dir)
|
53 |
|
|
|
54 |
def get_prompt(self, prompt, duration):
|
55 |
torch.cuda.empty_cache()
|
|
|
|
|
56 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
57 |
start = time.time()
|
58 |
|
@@ -88,6 +91,9 @@ class Text2Video():
|
|
88 |
|
89 |
save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
|
90 |
print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
|
|
|
|
|
|
|
91 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
92 |
|
93 |
def download_model(self, model_folder, model_filename):
|
|
|
36 |
self.download_model(model_folder, model_filename)
|
37 |
if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
|
38 |
self.download_internvideo2(model_folder)
|
39 |
+
self.agent = torch.load(os.path.join(model_folder, model_filename),map_location='cpu')
|
40 |
model_name = 'internvideo2'
|
41 |
|
42 |
# Get ViCLIP
|
43 |
viclip_global_instance = ViCLIPGlobalInstance(model_name)
|
44 |
if not viclip_global_instance._instantiated:
|
45 |
print("Instantiating InternVideo2")
|
46 |
+
viclip_global_instance.instantiate(device='cpu')
|
47 |
self.clip = viclip_global_instance.viclip
|
48 |
self.tokenizer = viclip_global_instance.viclip_tokenizer
|
49 |
|
|
|
51 |
if not os.path.exists(self.result_dir):
|
52 |
os.mkdir(self.result_dir)
|
53 |
|
54 |
+
@spaces.GPU
|
55 |
def get_prompt(self, prompt, duration):
|
56 |
torch.cuda.empty_cache()
|
57 |
+
self.agent.to('cuda')
|
58 |
+
self.clip.to('cuda')
|
59 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
60 |
start = time.time()
|
61 |
|
|
|
91 |
|
92 |
save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
|
93 |
print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
|
94 |
+
# Offload GPU
|
95 |
+
self.agent.to('cpu')
|
96 |
+
self.clip.to('cpu')
|
97 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
98 |
|
99 |
def download_model(self, model_folder, model_filename):
|