# Video-Infinity demo — Hugging Face Space app.py
import gradio as gr
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import json
import os
from src.video_crafter import VideoCrafterPipeline
from src.tools import DistController
from src.video_infinity.wrapper import DistWrapper
def init_pipeline(config):
    """Load the VideoCrafter pipeline and prepare it for this process's GPU.

    The GPU is chosen from ``config["devices"]`` by the current distributed
    rank, so each rank is pinned to its own device. Model weights are kept
    CPU-offloaded until needed and VAE slicing is enabled to reduce peak
    GPU memory.

    NOTE(review): assumes torch.distributed has already been initialized
    (dist.get_rank() would raise otherwise) — confirm against callers.
    """
    devices = config["devices"]
    gpu_index = devices[dist.get_rank() % len(devices)]

    pipeline = VideoCrafterPipeline.from_pretrained(
        'adamdad/videocrafterv2_diffusers',
        torch_dtype=torch.float16,
    )
    pipeline.enable_model_cpu_offload(gpu_id=gpu_index)
    pipeline.enable_vae_slicing()
    return pipeline
def run_inference(prompt, config):
    """Generate a video for *prompt* and return the resulting file path.

    Builds a distributed controller with rank 0 and world size 1 (this app
    runs a single process), wraps the pipeline for distributed inference,
    times the run, and returns whatever path ``DistWrapper.inference``
    produces.
    """
    controller = DistController(0, 1, config)
    wrapped_pipe = DistWrapper(init_pipeline(config), controller, config)

    started_at = time.time()
    video_path = wrapped_pipe.inference(
        prompt,
        config,
        config['pipe_configs'],
        config['plugin_configs'],
        additional_info={
            "full_config": config,
        },
    )
    elapsed = time.time() - started_at
    print(f"Inference finished. Time: {elapsed}")
    return video_path
def demo(input_text):
    """Gradio callback: generate a video for *input_text*, return its path.

    Ensures the output directory exists, assembles a minimal single-GPU
    configuration, and delegates the actual generation to run_inference.
    """
    base_path = "./results"
    # exist_ok=True avoids the check-then-create race of the former
    # os.path.exists / os.makedirs pair.
    os.makedirs(base_path, exist_ok=True)
    config = {
        "devices": [0],  # GPU indices, e.g. [0] for one GPU or [0, 1] for two
        "base_path": base_path,  # directory where generated videos are saved
        "pipe_configs": {
            "prompts": [input_text]
        },
        "plugin_configs": {}
    }
    video_path = run_inference(input_text, config)
    return video_path
# Minimal Gradio UI: one text prompt in, one generated video file out.
# launch() blocks and serves the app (standard Hugging Face Spaces entry).
iface = gr.Interface(fn=demo, inputs="text", outputs="video")
iface.launch()