Spaces:
Runtime error
Runtime error
feat(src): :rocket: First commit
Browse files- app.py +86 -4
- requirements.txt +6 -0
app.py
CHANGED
@@ -1,7 +1,89 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
def
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
4 |
+
from diffusers.utils import export_to_video
|
5 |
|
6 |
+
def generate_video(
|
7 |
+
prompt,
|
8 |
+
height=720,
|
9 |
+
width=1280,
|
10 |
+
num_frames=129,
|
11 |
+
num_inference_steps=30,
|
12 |
+
guidance_scale=6.0,
|
13 |
+
flow_shift=7.0
|
14 |
+
):
|
15 |
+
"""
|
16 |
+
Funci贸n para generar video usando HunyuanVideo
|
17 |
+
"""
|
18 |
+
try:
|
19 |
+
# Cargar el modelo transformer con bfloat16
|
20 |
+
model_id = "hunyuanvideo-community/HunyuanVideo"
|
21 |
+
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
22 |
+
model_id,
|
23 |
+
subfolder="transformer",
|
24 |
+
torch_dtype=torch.bfloat16
|
25 |
+
)
|
26 |
|
27 |
+
# Crear el pipeline con float16
|
28 |
+
pipe = HunyuanVideoPipeline.from_pretrained(
|
29 |
+
model_id,
|
30 |
+
transformer=transformer,
|
31 |
+
torch_dtype=torch.float16
|
32 |
+
)
|
33 |
+
|
34 |
+
# Habilitar tiling del VAE para ahorrar memoria
|
35 |
+
pipe.vae.enable_tiling()
|
36 |
+
|
37 |
+
# Mover el modelo a GPU si est谩 disponible
|
38 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
39 |
+
pipe.to(device)
|
40 |
+
|
41 |
+
# Generar el video
|
42 |
+
output = pipe(
|
43 |
+
prompt=prompt,
|
44 |
+
height=height,
|
45 |
+
width=width,
|
46 |
+
num_frames=num_frames,
|
47 |
+
num_inference_steps=num_inference_steps,
|
48 |
+
guidance_scale=guidance_scale,
|
49 |
+
flow_shift=flow_shift
|
50 |
+
).frames[0]
|
51 |
+
|
52 |
+
# Exportar a video
|
53 |
+
video_path = "output.mp4"
|
54 |
+
export_to_video(output, video_path, fps=15)
|
55 |
+
|
56 |
+
return video_path
|
57 |
+
|
58 |
+
except Exception as e:
|
59 |
+
return str(e)
|
60 |
+
|
61 |
+
# Crear la interfaz de Gradio
|
62 |
+
demo = gr.Interface(
|
63 |
+
fn=generate_video,
|
64 |
+
inputs=[
|
65 |
+
gr.Textbox(label="Prompt", placeholder="Describe el video que quieres generar..."),
|
66 |
+
gr.Slider(minimum=320, maximum=1280, value=720, step=16, label="Alto"),
|
67 |
+
gr.Slider(minimum=320, maximum=1280, value=1280, step=16, label="Ancho"),
|
68 |
+
gr.Slider(minimum=61, maximum=129, value=129, step=4, label="N煤mero de frames"),
|
69 |
+
gr.Slider(minimum=1, maximum=50, value=30, label="Pasos de inferencia"),
|
70 |
+
gr.Slider(minimum=1.0, maximum=20.0, value=6.0, label="Guidance Scale"),
|
71 |
+
gr.Slider(minimum=2.0, maximum=12.0, value=7.0, label="Flow Shift")
|
72 |
+
],
|
73 |
+
outputs=gr.Video(label="Video generado"),
|
74 |
+
title="Generador de Videos con HunyuanVideo",
|
75 |
+
description="""
|
76 |
+
Genera videos a partir de descripciones de texto usando HunyuanVideo.
|
77 |
+
- El prompt debe ser una descripci贸n clara del video que deseas generar
|
78 |
+
- Se recomienda usar resoluciones soportadas (ver tabla en la documentaci贸n)
|
79 |
+
- El n煤mero de frames debe ser de la forma 4k + 1 (ej: 61, 129)
|
80 |
+
- Flow shift: usar valores bajos (2-5) para resoluciones peque帽as y altos (7-12) para resoluciones grandes
|
81 |
+
""",
|
82 |
+
examples=[
|
83 |
+
["A cat walks on the grass, realistic style.", 720, 1280, 61, 30, 6.0, 7.0],
|
84 |
+
["A beautiful sunset over the ocean, cinematic style.", 544, 960, 129, 30, 6.0, 5.0]
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
if __name__ == "__main__":
|
89 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
torchvision>=0.15.0
|
3 |
+
diffusers>=0.25.0
|
4 |
+
transformers>=4.36.0
|
5 |
+
accelerate>=0.25.0
|
6 |
+
gradio>=4.0.0
|