ychenhq committed on
Commit 069147a
Parent: 28103cd

Upload 5 files

app.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
cog.yaml ADDED
@@ -0,0 +1,25 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+ build:
+   gpu: true
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_version: "3.11"
+   python_packages:
+     - "torch==2.0.1"
+     - "opencv-python==4.8.1.78"
+     - "torchvision==0.15.2"
+     - "pytorch_lightning==2.1.0"
+     - "einops==0.7.0"
+     - "imageio==2.31.6"
+     - "omegaconf==2.3.0"
+     - "transformers==4.35.0"
+     - "moviepy==1.0.3"
+     - "av==10.0.0"
+     - "decord==0.6.0"
+     - "kornia==0.7.0"
+     - "open-clip-torch==2.12.0"
+     - "xformers==0.0.21"
+ predict: "predict.py:Predictor"
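Taken together, this file defines the Replicate build: a GPU image with the listed system and Python packages, with the `predict` entry pointing Cog at the `Predictor` class in the `predict.py` file added below.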
final-year-project-443dd-df6f48af0796.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "type": "service_account",
+ "project_id": "final-year-project-443dd",
+ "private_key_id": "df6f48af0796ab27ae03fb99d08afca2ac2b00ef",
+ "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmOla5Gssdx196\n4OZyrsE1so4q3nc1fWjNs9PsQ/cm6lTHTENAMM4yHbr0no4b5jL2KgBFwAsIMAMI\nzmHJNc+r/3dnLcPOvnUH8PlkaZNpH/5eQueLz8is7QcqvtnImkg/v2wlXLXWKwWx\nlWyvW10UuYry5qsta3aclqxmhP1jem6QnQxKLiQUNdAPbqsbFyEA11QHzivsTAac\nGdDHF2V/yJ05dqRE+40EaFYbzTXHUBglC0SbgGL512KvpSC16qwFBbY9oy+jHQ55\n8uzVVw5OCSmMCI+UmOrMSe/sI67jHXgOK/GexrHNazh2XbZUSupPIIz1lsBXUl1D\n8L3UdiWVAgMBAAECggEAJwZnOcnaicE230hRkfcJESw8SEA2SG6K3lArnrpOGerF\nwIxc9YL/xbBJJgjbYB1pNXWi3r05WdC7xaN+PZjOipjNVYHfCHiaTST7x+EpZHLI\nayTV63L6r+5t0lFAG+Jst9qe7x6W6hLroUdtXrXaYnU089XHtkAWdqjBDMiIHIRO\nZM9fAnCK/0dShYa0oD1BrjrGCUDrYdJ9I3WJWU+LHBfTZfLXEWbKeE+6665bC7IY\nB9JqhMlbNJWqNwIrg/bB8lI1qIGBY7lEl32N4cQ/JXXpOtfZGx7EAlYiez+bbgnI\nbJN637gp95E8V4l1eSDoF4FdIiygVcghXavOz+AHQQKBgQDmD8NjgkZQ9iiD+1kM\nJUi5AY+xgwOPfR+/vQSM2XWe5Q2jKOR82327Hj3bgua9pWr5FlPRFOakHIohV6nx\nFHkU9LVFwA9tL2pbs+kditDwg8doJtU/wpUW9kYhJ1MAY6dyuRr53CT4XIscXlKX\nHlOK5NClSNY0wFdgIxrQ3vGR/QKBgQC4+Cb2/Chsuh2jt0mp5IESYk38f9E4/YA3\n/1m8aQIbEUfhT3Xihk/MyhOp5MisnACt4kBH2KnrFzB1FAXtAgJQMvP2hLZekTQs\nhYMD2MfsT+E1Fj/bquIh4rDmrAW2wal+HzFBcuqBo81xXrokZGood9TnDNwwow1f\nMus3AXNJeQKBgGaVqtNpWL9rNB+96TQQQAA24QMPX3wRGCIgP7IqmVcT3ePeLRw7\npzHTx1NlaEwyQaP2P8OgZUPScglyFJYqQd+FSntiq75NAUkIzS7eIlLNABLCFh7L\nPj2x7Q2Fgm5PAXCXd57oehfA9ErfCEbYP/pUE3FQLCvzhEKbBK8UanVlAoGBAIkk\nPEedmB9dMwKir/ROHsDRsD7JSgf2NK3QHumJ9ey5uFC+iIoGyX3uSfwKTBtmoz5J\nZR2f8AQFMoFr8iTS+4IY9TdPGKQvBr8H0qb0gO6eHz0sHPay0W0MVdsBqk7hcdi4\nKd375RFvsLAg6uR2qxsMFgelSlCpZA20hB9JbQAJAoGAEmCK/A7k4AJq0cWtad3y\n9wmUsvGFZUhqj1nYtZ2GchKWIcszM28G77AnT52vPNjSDfygQAVxQ7NSYIcwULiA\nMHL4pB8RQr6P4yXISh7dPG8dlrhefrm4KdVMZPOz0Cpry4KejYWKx/YMjqZxARDd\nZFRtycZMdS8kBvSHeyc4mH8=\n-----END PRIVATE KEY-----\n",
+ "client_email": "firebase-adminsdk-74lss@final-year-project-443dd.iam.gserviceaccount.com",
+ "client_id": "104174452867915111710",
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-74lss%40final-year-project-443dd.iam.gserviceaccount.com",
+ "universe_domain": "googleapis.com"
+ }
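A service-account key like this is normally consumed by the Firebase Admin SDK. As a minimal sketch of how it would typically be loaded, assuming the (unrendered) app.ipynb initializes Firebase with this file; the call below is illustrative and not taken from this commit:

import firebase_admin
from firebase_admin import credentials

# Illustrative only: assumes the app authenticates with the committed key file above.
cred = credentials.Certificate("final-year-project-443dd-df6f48af0796.json")
firebase_admin.initialize_app(cred)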
predict.py ADDED
@@ -0,0 +1,155 @@
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+
+
+ import os
+ import sys
+ import argparse
+ import random
+ from omegaconf import OmegaConf
+ from einops import rearrange, repeat
+ import torch
+ import torchvision
+ from pytorch_lightning import seed_everything
+ from cog import BasePredictor, Input, Path
+
+ sys.path.insert(0, "scripts/evaluation")
+ from funcs import (
+     batch_ddim_sampling,
+     load_model_checkpoint,
+     load_image_batch,
+     get_filelist,
+ )
+ from utils.utils import instantiate_from_config
+
+
+ class Predictor(BasePredictor):
+     def setup(self) -> None:
+         """Load the model into memory to make running multiple predictions efficient"""
+
+         ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
+         config_base = "configs/inference_t2v_1024_v1.0.yaml"
+         ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
+         config_i2v = "configs/inference_i2v_512_v1.0.yaml"
+
+         config_base = OmegaConf.load(config_base)
+         model_config_base = config_base.pop("model", OmegaConf.create())
+         self.model_base = instantiate_from_config(model_config_base)
+         self.model_base = self.model_base.cuda()
+         self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
+         self.model_base.eval()
+
+         config_i2v = OmegaConf.load(config_i2v)
+         model_config_i2v = config_i2v.pop("model", OmegaConf.create())
+         self.model_i2v = instantiate_from_config(model_config_i2v)
+         self.model_i2v = self.model_i2v.cuda()
+         self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
+         self.model_i2v.eval()
+
+     def predict(
+         self,
+         task: str = Input(
+             description="Choose the task.",
+             choices=["text2video", "image2video"],
+             default="text2video",
+         ),
+         prompt: str = Input(
+             description="Prompt for video generation.",
+             default="A tiger walks in the forest, photorealistic, 4k, high definition.",
+         ),
+         image: Path = Input(
+             description="Input image for image2video task.", default=None
+         ),
+         ddim_steps: int = Input(description="Number of denoising steps.", default=50),
+         unconditional_guidance_scale: float = Input(
+             description="Classifier-free guidance scale.", default=12.0
+         ),
+         seed: int = Input(
+             description="Random seed. Leave blank to randomize the seed.", default=None
+         ),
+         save_fps: int = Input(
+             description="Frames per second for the generated video.", default=10
+         ),
+     ) -> Path:
+
+         width = 1024 if task == "text2video" else 512
+         height = 576 if task == "text2video" else 320
+         model = self.model_base if task == "text2video" else self.model_i2v
+
+         if task == "image2video":
+             assert image is not None, "Please provide an image for image2video generation."
+
+         if seed is None:
+             seed = int.from_bytes(os.urandom(2), "big")
+         print(f"Using seed: {seed}")
+         seed_everything(seed)
+
+         args = argparse.Namespace(
+             mode="base" if task == "text2video" else "i2v",
+             savefps=save_fps,
+             n_samples=1,
+             ddim_steps=ddim_steps,
+             ddim_eta=1.0,
+             bs=1,
+             height=height,
+             width=width,
+             frames=-1,
+             fps=28 if task == "text2video" else 8,
+             unconditional_guidance_scale=unconditional_guidance_scale,
+             unconditional_guidance_scale_temporal=None,
+         )
+
+         ## latent noise shape
+         h, w = args.height // 8, args.width // 8
+         frames = model.temporal_length if args.frames < 0 else args.frames
+         channels = model.channels
+
+         batch_size = 1
+         noise_shape = [batch_size, channels, frames, h, w]
+         fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+         prompts = [prompt]
+         text_emb = model.get_learned_conditioning(prompts)
+
+         if args.mode == "base":
+             cond = {"c_crossattn": [text_emb], "fps": fps}
+         elif args.mode == "i2v":
+             cond_images = load_image_batch([str(image)], (args.height, args.width))
+             cond_images = cond_images.to(model.device)
+             img_emb = model.get_image_embeds(cond_images)
+             imtext_cond = torch.cat([text_emb, img_emb], dim=1)
+             cond = {"c_crossattn": [imtext_cond], "fps": fps}
+         else:
+             raise NotImplementedError
+
+         ## inference
+         batch_samples = batch_ddim_sampling(
+             model,
+             cond,
+             noise_shape,
+             args.n_samples,
+             args.ddim_steps,
+             args.ddim_eta,
+             args.unconditional_guidance_scale,
+         )
+
+         out_path = "/tmp/output.mp4"
+         vid_tensor = batch_samples[0]
+         video = vid_tensor.detach().cpu()
+         video = torch.clamp(video.float(), -1.0, 1.0)
+         video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
+
+         frame_grids = [
+             torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
+             for framesheet in video
+         ]  # [3, 1*h, n*w]
+         grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
+         grid = (grid + 1.0) / 2.0
+         grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+         torchvision.io.write_video(
+             out_path,
+             grid,
+             fps=args.savefps,
+             video_codec="h264",
+             options={"crf": "10"},
+         )
+         return Path(out_path)
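For reference, a minimal local driver showing how this interface is exercised; it mirrors what `cog predict` does (one `setup()`, then `predict()` per request). The argument values are illustrative and assume the checkpoints referenced in `setup()` are present:

from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads both the 1024x576 text2video and 512x320 image2video models

# Illustrative call; all inputs are passed explicitly because the Input(...) defaults
# only take effect when Cog itself drives the predictor.
video_path = predictor.predict(
    task="text2video",
    prompt="A tiger walks in the forest, photorealistic, 4k, high definition.",
    image=None,
    ddim_steps=50,
    unconditional_guidance_scale=12.0,
    seed=123,
    save_fps=10,
)
print(video_path)  # /tmp/output.mp4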
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ decord==0.6.0
+ einops==0.3.0
+ imageio==2.9.0
+ numpy==1.24.2
+ omegaconf==2.1.1
+ opencv_python>=4.1.2
+ pandas==2.0.0
+ Pillow==9.5.0
+ pytorch_lightning==1.8.3
+ PyYAML==6.0
+ setuptools==65.6.3
+ torch==2.0.0
+ torchvision>=0.7.0
+ tqdm==4.65.0
+ transformers==4.25.1
+ moviepy>=1.0.3
+ av
+ xformers
+ gradio
+ timm
+ scikit-learn
+ open_clip_torch==2.22.0
+ kornia
+ sk-video>=1.1.10
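Note that these pins differ from the ones in cog.yaml above (for example pytorch_lightning==1.8.3 vs 2.1.0, transformers==4.25.1 vs 4.35.0, open_clip_torch==2.22.0 vs 2.12.0); this file appears to serve the Gradio app, while cog.yaml defines the Replicate container, so the two environments are resolved independently.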