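"""Demo backend for depth-guided text-to-video generation.

Loads the VideoCrafter/LVDM base text-to-video model together with a depth
adapter: depth maps extracted from an input clip steer the generated video so
that it follows the input's structure while matching the text prompt.
"""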
import os
import math
import time

from omegaconf import OmegaConf

import torch
from decord import VideoReader, cpu
import torchvision
from pytorch_lightning import seed_everything

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis

import torchvision.transforms._transforms_video as transforms_video
from huggingface_hub import hf_hub_download


def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
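    """Read `video_frames` frames from `filepath`, resized to `video_size`
    (h, w) and sampled every `frame_stride` frames (0 = adaptive stride).
    Returns a [c, t, h, w] float tensor normalized to [-1, 1] plus an
    info/warning string."""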
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)
    # a stride of 0 means "spread video_frames evenly across the whole video"
    if frame_stride != 0:
        if frame_stride * (video_frames - 1) >= max_frames:
            info_str += "Warning: the requested frame stride exceeds the video length; falling back to an adaptive stride.\n"
            frame_stride = 0
    if frame_stride == 0:
        frame_stride = max_frames / video_frames
    if frame_stride > 100:
        frame_stride = 100
        info_str += "Warning: the input video is very long; clamping the frame stride to 100, so only the beginning of the video will be sampled.\n"
    info_str += f"Frame stride is set to {frame_stride}"
    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)

    ## [t,h,w,c] -> [c,t,h,w]
    frame_tensor = torch.from_numpy(frames.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2  # normalize pixel values to [-1, 1]
    return frame_tensor, info_str

class VideoControl:
    def __init__(self, result_dir='./tmp/') -> None:
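        """Download checkpoints if needed, then build the base T2V model,
        load the base and depth-adapter weights, and move it to the GPU."""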
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            # prefer a copy staged in shared memory (tmpfs), which loads faster
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint not found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
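        """Run depth-guided synthesis: sample a video matching `input_prompt`
        while following the depth maps extracted from `input_video`.

        Returns (info_str, origin_path, depth_path, video_path); the three
        paths are None when loading or sampling fails."""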
        torch.cuda.empty_cache()
        # clamp user inputs to ranges the demo can handle
        resolution = max(64, min(int(resolution), 512))
        resolution = (resolution // 64) * 64  # the model requires spatial dims that are multiples of 64
        video_frames = min(video_frames, 64)
        vc_steps = min(vc_steps, 60)
        ## load the input video; bail out if it cannot be decoded
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            os.remove(input_video)
            return 'please input a valid video', None, None, None

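        # resize so the longer side equals `resolution`; the shorter side is
        # later rounded down to a multiple of 64 and center-cropped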
        if h > w:
            scale = h / resolution
        else:
            scale = w / resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:
            os.remove(input_video)
            return 'load video error', None, None, None
        if h > w:
            w = int(w//64)*64
        else:
            h = int(h//64)*64
        spatial_transform = transforms_video.CenterCropVideo((h,w))
        video = spatial_transform(video)
        print('video shape', video.shape)

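        # latent-space noise shape: [batch, channels, frames, h/8, w/8]
        # (the autoencoder downsamples each spatial dim by 8x, hence h//8, w//8)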
        rh, rw = h//8, w//8
        bs = 1
        channels = self.model.channels
        frames = video_frames  # generate as many frames as were sampled from the input
        noise_shape = [bs, channels, frames, rh, rw]
        
        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except RuntimeError:  # CUDA OOM surfaces as a RuntimeError
            torch.cuda.empty_cache()
            info_str = "OOM: please use a smaller resolution or fewer frames"
            return info_str, None, None, None
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # derive a filesystem-safe, bounded-length filename from the prompt
        filename = prompt.replace("/", "_slash_").replace(" ", "_")[:200]
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
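        # save the (resized) input clip, the depth-map condition, and the generated sample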
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)

        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
        # delete video
        (path, input_filename) = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('delete input video')
        # print(input_video)
        return info_str, origin_path, depth_path, video_path

    def download_model(self):
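        """Fetch the required checkpoints from the Hugging Face Hub if they
        are not already present locally (base T2V, depth adapter, MiDaS DPT)."""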
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = ['models/base_t2v/model_rm_wtm.ckpt',
                         "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth",
                         "models/adapter_t2v_depth/dpt_hybrid-midas.pt"
                        ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)

if __name__ == "__main__":
    # quick smoke test on the bundled example clip
    vc = VideoControl('./result')
    info_str, origin_path, depth_path, video_path = vc.get_video('input/flamingo.mp4', "An ostrich walking in the desert, photorealistic, 4k")