# General
import os
from os.path import join as opj
import datetime
import torch
from einops import rearrange

# Utilities (supplies the helpers used below: add_margin, resize_and_keep,
# center_crop, resize_to_fit, pad_to_fit)
from t2v_enhanced.inference_utils import *

from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np

import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image

# PIL image -> uint8 tensor of shape (C, H, W); values stay in [0, 255].
transform = transforms.Compose([
    transforms.PILToTensor()
])


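# Text-to-video short-clip generation with the ModelScope pipeline
# (loaded through diffusers). Returns a (F, C, 256, 256) float32 tensor
# on `device`.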
def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
    frames = ms_model(prompt,
                      num_inference_steps=t,
                      generator=inference_generator,
                      eta=1.0,
                      height=256,
                      width=256,
                      latents=None).frames
    # .frames is batched as (B, F, H, W, C); keep the first video and put
    # channels before the spatial dims.
    frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    return rearrange(frames[0], "F H W C -> F C H W")

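# Text-to-video short-clip generation with an AnimateDiff-style pipeline
# (16 frames). Output is resized to 256x256 and scaled into [0, 1].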
def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
    frames = ad_model(prompt,
                      negative_prompt="bad quality, worse quality",
                      num_frames=16,
                      num_inference_steps=t,
                      generator=inference_generator,
                      guidance_scale=7.5).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames

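# Single-image generation with Stable Diffusion XL, used to synthesize a
# conditioning image when none is provided to svd_short_gen.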
def sdxl_image_gen(prompt, sdxl_model):
    image = sdxl_model(prompt=prompt).images[0]
    return image

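# Image-to-video short-clip generation with Stable Video Diffusion. The
# conditioning image (given, loaded from a path, or synthesized via SDXL)
# is fitted to SVD's 1024x576 input; the side padding is cropped off the
# generated frames afterwards.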
def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
    if image is None:
        # No conditioning image: synthesize one from the prompt.
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
    else:
        # Either a path on disk or an in-memory array.
        image = load_image(image) if isinstance(image, str) else Image.fromarray(np.uint8(image))
        image = resize_and_keep(image)
        image = center_crop(image)
    # Pad 224 px on each side: 576 + 2 * 224 = 1024, SVD's native width.
    image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))

    frames = svd_model(image, num_inference_steps=t, decode_chunk_size=4,
                       generator=inference_generator).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep 16 frames, crop the side padding back off, resize to 256x256,
    # and scale into [0, 1].
    frames = frames[:16, :, :, 224:-224]
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames


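# Autoregressive long-video generation with the StreamingT2V model: the
# generation parameters are passed to the Lightning trainer through
# `predict_cfg`, and the result is written to the configured result folder
# under `result_file_stem`.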
def stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
    trainer = stream_cli.trainer
    trainer.limit_predict_batches = 1

    trainer.predict_cfg = {
        "predict_dir": stream_cli.config["result_fol"].as_posix(),
        "result_file_stem": result_file_stem,
        "prompt": prompt,
        "video": short_video,
        "seed": seed,
        "num_inference_steps": t,
        "guidance_scale": image_guidance,
        "n_autoregressive_generations": n_autoreg_gen,
    }

    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)


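# Video enhancement/upscaling with a ModelScope video-to-video model
# (e.g. MS-Vid2Vid-XL): the input clip is downscaled, refit to
# `upscale_size` (optionally padded), enhanced, and finally cropped back
# to a square.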
def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
    downscale = cfg_v2v['downscale']
    upscale_size = cfg_v2v['upscale_size']
    pad = cfg_v2v['pad']

    # Build a unique, filesystem-safe output name from the prompt and time.
    now = datetime.datetime.now()
    now = str(now.time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + now
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Downscale the frames, then resize them to fit the upscale size
    video = [Image.fromarray(frame).resize((w // downscale, h // downscale)) for frame in video_frames]
    video = [resize_to_fit(frame, upscale_size) for frame in video]

    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]

    temp_video_mp4 = opj(where_to_log, 'temp_' + now + '.mp4')
    imageio.mimsave(temp_video_mp4, video, fps=8)

    # The model reads the temporary clip and writes the enhanced result
    # directly to enhanced_video_mp4; it also reports the output path.
    p_input = {
        'video_path': temp_video_mp4,
        'text': prompt
    }
    enhanced_video_mp4 = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
    os.remove(temp_video_mp4)

    if square:
        # Crop 280 px of side padding per side to square the frames
        # (280 matches a 1280 px wide upscale: 1280 - 2 * 280 = 720).
        video_frames = imageio.mimread(enhanced_video_mp4)
        video_frames_square = [frame[:, 280:-280, :] for frame in video_frames]
        imageio.mimsave(enhanced_video_mp4, video_frames_square, fps=8)

    return enhanced_video_mp4
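

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original
    # module): load the ModelScope text-to-video pipeline through diffusers
    # and run the short-clip generator above. The checkpoint name follows
    # the public model card; the surrounding repo may build its models
    # differently.
    from diffusers import DiffusionPipeline

    ms_model = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
    ).to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(33)
    clip = ms_short_gen("A cat walking on grass", ms_model, generator)
    print(clip.shape)  # roughly (16, 3, 256, 256)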