Commit
•
702754c
1
Parent(s):
dae6484
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,8 @@ from huggingface_hub import snapshot_download
|
|
7 |
from pyramid_dit import PyramidDiTForVideoGeneration
|
8 |
from diffusers.utils import export_to_video
|
9 |
|
|
|
|
|
10 |
# Constants
|
11 |
MODEL_PATH = "pyramid-flow-model"
|
12 |
MODEL_REPO = "rain1011/pyramid-flow-sd3"
|
@@ -35,6 +37,7 @@ def load_model():
|
|
35 |
model = load_model()
|
36 |
|
37 |
# Text-to-video generation function
|
|
|
38 |
def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
|
39 |
temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
|
40 |
torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
|
@@ -58,6 +61,7 @@ def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
|
|
58 |
return output_path
|
59 |
|
60 |
# Image-to-video generation function
|
|
|
61 |
def generate_video_from_image(image, prompt, duration, video_guidance_scale):
|
62 |
temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
|
63 |
torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
|
|
|
7 |
from pyramid_dit import PyramidDiTForVideoGeneration
|
8 |
from diffusers.utils import export_to_video
|
9 |
|
10 |
+
import spaces
|
11 |
+
|
12 |
# Constants
|
13 |
MODEL_PATH = "pyramid-flow-model"
|
14 |
MODEL_REPO = "rain1011/pyramid-flow-sd3"
|
|
|
37 |
model = load_model()
|
38 |
|
39 |
# Text-to-video generation function
|
40 |
+
@spaces.GPU(duration=240)
|
41 |
def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
|
42 |
temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
|
43 |
torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
|
|
|
61 |
return output_path
|
62 |
|
63 |
# Image-to-video generation function
|
64 |
+
@spaces.GPU(duration=240)
|
65 |
def generate_video_from_image(image, prompt, duration, video_guidance_scale):
|
66 |
temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
|
67 |
torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
|