File size: 1,834 Bytes
898c6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import json
import os

from src.video_crafter import VideoCrafterPipeline
from src.tools import DistController
from src.video_infinity.wrapper import DistWrapper

def init_pipeline(config):
    """Load the VideoCrafter2 diffusers pipeline with CPU offload and VAE slicing.

    The weights are loaded in float16 from ``'adamdad/videocrafterv2_diffusers'``.
    Model CPU offload targets one of the GPUs listed in ``config["devices"]``,
    chosen as ``devices[rank % len(devices)]``.

    Args:
        config: Dict with a non-empty ``"devices"`` list of GPU indices.

    Returns:
        The configured ``VideoCrafterPipeline`` instance.

    Raises:
        ValueError: If ``config["devices"]`` is empty (previously this surfaced
            as a ZeroDivisionError after the model had already been loaded).
    """
    devices = config["devices"]
    if not devices:
        raise ValueError("config['devices'] must list at least one GPU index")

    # Spread processes across the listed GPUs by rank. dist.get_rank() raises
    # when no default process group exists, so treat that case as rank 0.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    pipe = VideoCrafterPipeline.from_pretrained(
        'adamdad/videocrafterv2_diffusers',
        torch_dtype=torch.float16,
    )
    pipe.enable_model_cpu_offload(
        gpu_id=devices[rank % len(devices)],
    )
    pipe.enable_vae_slicing()
    return pipe

def run_inference(prompt, config):
    """Build the pipeline and distributed wrapper, then generate one video.

    Args:
        prompt: Text prompt forwarded to ``DistWrapper.inference``.
        config: Dict that must contain ``"pipe_configs"`` and
            ``"plugin_configs"``. It is also passed whole to
            ``DistController``, ``init_pipeline`` (which reads ``"devices"``)
            and ``DistWrapper``.

    Returns:
        The value returned by ``DistWrapper.inference``; the caller treats it
        as the path of the generated video.

    Raises:
        KeyError: If ``"pipe_configs"`` or ``"plugin_configs"`` is missing.
    """
    # Read required sub-configs up front so a malformed config fails before
    # the expensive model download/load below.
    pipe_configs = config['pipe_configs']
    plugin_configs = config['plugin_configs']

    # Presumably (rank=0, world_size=1), i.e. a single-process run; confirm
    # against DistController's signature.
    # NOTE(review): a new controller and pipeline are built on every call. If
    # DistController initializes the default process group, a second call in
    # the same process may fail; consider building these once and reusing.
    dist_controller = DistController(0, 1, config)
    pipe = init_pipeline(config)
    dist_pipe = DistWrapper(pipe, dist_controller, config)

    # perf_counter is monotonic, unlike time.time(), so it is the correct
    # clock for measuring elapsed time.
    start = time.perf_counter()
    video_path = dist_pipe.inference(
        prompt,
        config,
        pipe_configs,
        plugin_configs,
        additional_info={
            "full_config": config,
        },
    )
    print(f"Inference finished. Time: {time.perf_counter() - start}")
    return video_path

def demo(input_text):
    """Gradio callback: generate a video for ``input_text`` and return its path.

    Ensures the output directory ``./results`` exists, builds a single-prompt
    inference config, and delegates to ``run_inference``.

    Args:
        input_text: Prompt text entered in the Gradio textbox.

    Returns:
        Whatever ``run_inference`` returns, which is handed to the Gradio
        ``"video"`` output.
    """
    base_path = "./results"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(base_path, exist_ok=True)

    config = {
        # GPU indices to use, e.g. [0] for a single GPU or [0, 1] for two.
        "devices": [0],
        # Directory where the generated videos are meant to be saved.
        "base_path": base_path,
        "pipe_configs": {
            "prompts": [input_text],
        },
        "plugin_configs": {},
    }
    return run_inference(input_text, config)

# Gradio UI: a text prompt goes in and the generated video comes out via `demo`.
iface = gr.Interface(fn=demo, inputs="text", outputs="video")

# Launch the web server only when run as a script. Importing this module
# (e.g. to reuse `demo` or `run_inference`) then does not start a server.
if __name__ == "__main__":
    iface.launch()