import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image

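# Model loading sits under gr.NO_RELOAD so that Gradio's auto-reload mode
# (`gradio app.py`) skips it when the source changes, instead of re-downloading
# and re-initializing both pipelines on every edit.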
if gr.NO_RELOAD:
    n_steps = 40  # SDXL denoising steps for the still-frame stage
    high_noise_frac = 0.8  # base/refiner handoff point; only used by the disabled refiner path below
    negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
    generator = torch.manual_seed(8888)  # fixed seed so the video stage is reproducible

    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", 
        torch_dtype=torch.float16, 
        variant="fp16", 
        use_safetensors=True,
    )
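    # Optional SDXL refiner stage (disabled): with it enabled, the base pipeline
    # stops at denoising_end=high_noise_frac and outputs latents, and the refiner
    # finishes denoising from denoising_start=high_noise_frac.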
    # refiner = DiffusionPipeline.from_pretrained(
    #     "stabilityai/stable-diffusion-xl-refiner-1.0",
    #     text_encoder_2=base.text_encoder_2,
    #     vae=base.vae,
    #     torch_dtype=torch.float16,
    #     use_safetensors=True,
    #     variant="fp16",
    # )
    # refiner.to("cuda")
    # base.to("cuda")
    # refiner.enable_model_cpu_offload()
    base.enable_model_cpu_offload()
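    # I2VGen-XL animates the SDXL still into a short clip. CPU offload moves each
    # model component to the GPU only while it is in use, and forward chunking
    # splits the UNet's feed-forward layers into smaller chunks; both trade speed
    # for lower peak VRAM.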
    pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
    pipeline.enable_model_cpu_offload()
    pipeline.unet.enable_forward_chunking()

def generate(prompt: str):
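    """Generate a still frame with SDXL, then animate it into a short video with I2VGen-XL."""
    # Stage 1: render a single conditioning frame from the prompt.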
    image = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        # denoising_end=high_noise_frac,  # re-enable together with the refiner
        # output_type="latent",  # the refiner consumes latents rather than PIL images
    ).images[0]
    # image = refiner(
    #     prompt=prompt,
    #     num_inference_steps=n_steps,
    #     denoising_start=high_noise_frac,
    #     image=image,
    # ).images[0]
    image.save("frame.jpg")  # save the intermediate frame for inspection
    image = to_tensor(image)  # PIL -> float CHW tensor in [0, 1]
    # Stage 2: animate the still frame into a short clip.
    frames: list[Image.Image] = pipeline(
        prompt=prompt,
        image=image,
        num_inference_steps=50,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
        decode_chunk_size=6,  # decode the VAE a few frames at a time to cap memory
    ).frames[0]
    # PIL frames -> (T, H, W, C) uint8 tensor, the layout write_video expects.
    video = torch.stack([to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames])
    torchvision.io.write_video("video.mp4", video, fps=4)
    return "video.mp4"

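# Minimal Gradio wiring: the "text" shortcut becomes a Textbox, and gr.Video
# plays the file path that generate() returns.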
app = gr.Interface(
    fn=generate,
    inputs=["text"],
    outputs=gr.Video()
)

if __name__ == "__main__":
    app.launch()