Spaces · Runtime error

Upload 5 files

- app.ipynb +0 -0
- cog.yaml +25 -0
- final-year-project-443dd-df6f48af0796.json +13 -0
- predict.py +155 -0
- requirements.txt +24 -0
app.ipynb
ADDED
The diff for this file is too large to render.
cog.yaml
ADDED
@@ -0,0 +1,25 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_version: "3.11"
+  python_packages:
+    - "torch==2.0.1"
+    - "opencv-python==4.8.1.78"
+    - "torchvision==0.15.2"
+    - "pytorch_lightning==2.1.0"
+    - "einops==0.7.0"
+    - "imageio==2.31.6"
+    - "omegaconf==2.3.0"
+    - "transformers==4.35.0"
+    - "moviepy==1.0.3"
+    - "av==10.0.0"
+    - "decord==0.6.0"
+    - "kornia==0.7.0"
+    - "open-clip-torch==2.12.0"
+    - "xformers==0.0.21"
+predict: "predict.py:Predictor"
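The `predict` entry uses Cog's `"file.py:ClassName"` convention: it points the runtime at the `Predictor` class defined in predict.py below. An illustrative sketch of how such an entry point resolves to a class (Cog's actual loader is more involved; `resolve_entry_point` is a hypothetical helper):

# Hypothetical resolver for a Cog-style "file.py:ClassName" entry point;
# illustrative only, not Cog's real loading code.
import importlib

def resolve_entry_point(spec: str) -> type:
    module_path, class_name = spec.split(":")
    # "predict.py" -> module name "predict"
    module = importlib.import_module(module_path.removesuffix(".py"))
    return getattr(module, class_name)

PredictorCls = resolve_entry_point("predict.py:Predictor")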
final-year-project-443dd-df6f48af0796.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "type": "service_account",
+  "project_id": "final-year-project-443dd",
+  "private_key_id": "df6f48af0796ab27ae03fb99d08afca2ac2b00ef",
+  "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmOla5Gssdx196\n4OZyrsE1so4q3nc1fWjNs9PsQ/cm6lTHTENAMM4yHbr0no4b5jL2KgBFwAsIMAMI\nzmHJNc+r/3dnLcPOvnUH8PlkaZNpH/5eQueLz8is7QcqvtnImkg/v2wlXLXWKwWx\nlWyvW10UuYry5qsta3aclqxmhP1jem6QnQxKLiQUNdAPbqsbFyEA11QHzivsTAac\nGdDHF2V/yJ05dqRE+40EaFYbzTXHUBglC0SbgGL512KvpSC16qwFBbY9oy+jHQ55\n8uzVVw5OCSmMCI+UmOrMSe/sI67jHXgOK/GexrHNazh2XbZUSupPIIz1lsBXUl1D\n8L3UdiWVAgMBAAECggEAJwZnOcnaicE230hRkfcJESw8SEA2SG6K3lArnrpOGerF\nwIxc9YL/xbBJJgjbYB1pNXWi3r05WdC7xaN+PZjOipjNVYHfCHiaTST7x+EpZHLI\nayTV63L6r+5t0lFAG+Jst9qe7x6W6hLroUdtXrXaYnU089XHtkAWdqjBDMiIHIRO\nZM9fAnCK/0dShYa0oD1BrjrGCUDrYdJ9I3WJWU+LHBfTZfLXEWbKeE+6665bC7IY\nB9JqhMlbNJWqNwIrg/bB8lI1qIGBY7lEl32N4cQ/JXXpOtfZGx7EAlYiez+bbgnI\nbJN637gp95E8V4l1eSDoF4FdIiygVcghXavOz+AHQQKBgQDmD8NjgkZQ9iiD+1kM\nJUi5AY+xgwOPfR+/vQSM2XWe5Q2jKOR82327Hj3bgua9pWr5FlPRFOakHIohV6nx\nFHkU9LVFwA9tL2pbs+kditDwg8doJtU/wpUW9kYhJ1MAY6dyuRr53CT4XIscXlKX\nHlOK5NClSNY0wFdgIxrQ3vGR/QKBgQC4+Cb2/Chsuh2jt0mp5IESYk38f9E4/YA3\n/1m8aQIbEUfhT3Xihk/MyhOp5MisnACt4kBH2KnrFzB1FAXtAgJQMvP2hLZekTQs\nhYMD2MfsT+E1Fj/bquIh4rDmrAW2wal+HzFBcuqBo81xXrokZGood9TnDNwwow1f\nMus3AXNJeQKBgGaVqtNpWL9rNB+96TQQQAA24QMPX3wRGCIgP7IqmVcT3ePeLRw7\npzHTx1NlaEwyQaP2P8OgZUPScglyFJYqQd+FSntiq75NAUkIzS7eIlLNABLCFh7L\nPj2x7Q2Fgm5PAXCXd57oehfA9ErfCEbYP/pUE3FQLCvzhEKbBK8UanVlAoGBAIkk\nPEedmB9dMwKir/ROHsDRsD7JSgf2NK3QHumJ9ey5uFC+iIoGyX3uSfwKTBtmoz5J\nZR2f8AQFMoFr8iTS+4IY9TdPGKQvBr8H0qb0gO6eHz0sHPay0W0MVdsBqk7hcdi4\nKd375RFvsLAg6uR2qxsMFgelSlCpZA20hB9JbQAJAoGAEmCK/A7k4AJq0cWtad3y\n9wmUsvGFZUhqj1nYtZ2GchKWIcszM28G77AnT52vPNjSDfygQAVxQ7NSYIcwULiA\nMHL4pB8RQr6P4yXISh7dPG8dlrhefrm4KdVMZPOz0Cpry4KejYWKx/YMjqZxARDd\nZFRtycZMdS8kBvSHeyc4mH8=\n-----END PRIVATE KEY-----\n",
+  "client_email": "firebase-adminsdk-74lss@final-year-project-443dd.iam.gserviceaccount.com",
+  "client_id": "104174452867915111710",
+  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+  "token_uri": "https://oauth2.googleapis.com/token",
+  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-74lss%40final-year-project-443dd.iam.gserviceaccount.com",
+  "universe_domain": "googleapis.com"
+}
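This file is a Firebase Admin SDK service-account key, presumably consumed by the notebook app. A minimal sketch of the typical initialization, assuming the `firebase-admin` package is installed (it is not pinned in either dependency file in this commit):

# Typical firebase-admin initialization from a service-account key file.
# Assumption: the firebase-admin package is installed; it is not listed
# in requirements.txt or cog.yaml.
import firebase_admin
from firebase_admin import credentials

cred = credentials.Certificate("final-year-project-443dd-df6f48af0796.json")
firebase_admin.initialize_app(cred)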
predict.py
ADDED
@@ -0,0 +1,155 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+
+import os
+import sys
+import argparse
+import random
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+import torch
+import torchvision
+from pytorch_lightning import seed_everything
+from cog import BasePredictor, Input, Path
+
+sys.path.insert(0, "scripts/evaluation")
+from funcs import (
+    batch_ddim_sampling,
+    load_model_checkpoint,
+    load_image_batch,
+    get_filelist,
+)
+from utils.utils import instantiate_from_config
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+
+        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
+        config_base = "configs/inference_t2v_1024_v1.0.yaml"
+        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
+        config_i2v = "configs/inference_i2v_512_v1.0.yaml"
+
+        config_base = OmegaConf.load(config_base)
+        model_config_base = config_base.pop("model", OmegaConf.create())
+        self.model_base = instantiate_from_config(model_config_base)
+        self.model_base = self.model_base.cuda()
+        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
+        self.model_base.eval()
+
+        config_i2v = OmegaConf.load(config_i2v)
+        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
+        self.model_i2v = instantiate_from_config(model_config_i2v)
+        self.model_i2v = self.model_i2v.cuda()
+        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
+        self.model_i2v.eval()
+
+    def predict(
+        self,
+        task: str = Input(
+            description="Choose the task.",
+            choices=["text2video", "image2video"],
+            default="text2video",
+        ),
+        prompt: str = Input(
+            description="Prompt for video generation.",
+            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
+        ),
+        image: Path = Input(
+            description="Input image for image2video task.", default=None
+        ),
+        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
+        unconditional_guidance_scale: float = Input(
+            description="Classifier-free guidance scale.", default=12.0
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
+        ),
+        save_fps: int = Input(
+            description="Frame per second for the generated video.", default=10
+        ),
+    ) -> Path:
+
+        width = 1024 if task == "text2video" else 512
+        height = 576 if task == "text2video" else 320
+        model = self.model_base if task == "text2video" else self.model_i2v
+
+        if task == "image2video":
+            assert image is not None, "Please provide image for image2video generation."
+
+        if seed is None:
+            seed = int.from_bytes(os.urandom(2), "big")
+        print(f"Using seed: {seed}")
+        seed_everything(seed)
+
+        args = argparse.Namespace(
+            mode="base" if task == "text2video" else "i2v",
+            savefps=save_fps,
+            n_samples=1,
+            ddim_steps=ddim_steps,
+            ddim_eta=1.0,
+            bs=1,
+            height=height,
+            width=width,
+            frames=-1,
+            fps=28 if task == "text2video" else 8,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_guidance_scale_temporal=None,
+        )
+
+        ## latent noise shape
+        h, w = args.height // 8, args.width // 8
+        frames = model.temporal_length if args.frames < 0 else args.frames
+        channels = model.channels
+
+        batch_size = 1
+        noise_shape = [batch_size, channels, frames, h, w]
+        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+        prompts = [prompt]
+        text_emb = model.get_learned_conditioning(prompts)
+
+        if args.mode == "base":
+            cond = {"c_crossattn": [text_emb], "fps": fps}
+        elif args.mode == "i2v":
+            cond_images = load_image_batch([str(image)], (args.height, args.width))
+            cond_images = cond_images.to(model.device)
+            img_emb = model.get_image_embeds(cond_images)
+            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
+            cond = {"c_crossattn": [imtext_cond], "fps": fps}
+        else:
+            raise NotImplementedError
+
+        ## inference
+        batch_samples = batch_ddim_sampling(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+        )
+
+        out_path = "/tmp/output.mp4"
+        vid_tensor = batch_samples[0]
+        video = vid_tensor.detach().cpu()
+        video = torch.clamp(video.float(), -1.0, 1.0)
+        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
+
+        frame_grids = [
+            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
+            for framesheet in video
+        ]  # [3, 1*h, n*w]
+        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
+        grid = (grid + 1.0) / 2.0
+        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+        torchvision.io.write_video(
+            out_path,
+            grid,
+            fps=args.savefps,
+            video_codec="h264",
+            options={"crf": "10"},
+        )
+        return Path(out_path)
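Because `Predictor.predict` is a plain method, it can be smoke-tested outside the Cog runtime. A minimal sketch, assuming the checkpoints, configs, and scripts/evaluation helpers referenced in `setup()` are present and a CUDA GPU is available; note the `Input` defaults do not apply on a direct call, so every argument is passed explicitly:

# Local smoke test for the Predictor above. Assumes the checkpoints/,
# configs/, and scripts/evaluation trees referenced in setup() exist
# and a CUDA GPU is available.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads both the text2video and image2video models
out_path = predictor.predict(
    task="text2video",
    prompt="A tiger walks in the forest, photorealistic, 4k, high definition.",
    image=None,
    ddim_steps=50,
    unconditional_guidance_scale=12.0,
    seed=42,
    save_fps=10,
)
print(out_path)  # -> /tmp/output.mp4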
requirements.txt
ADDED
@@ -0,0 +1,24 @@
+decord==0.6.0
+einops==0.3.0
+imageio==2.9.0
+numpy==1.24.2
+omegaconf==2.1.1
+opencv_python>=4.1.2
+pandas==2.0.0
+Pillow==9.5.0
+pytorch_lightning==1.8.3
+PyYAML==6.0
+setuptools==65.6.3
+torch==2.0.0
+torchvision>=0.7.0
+tqdm==4.65.0
+transformers==4.25.1
+moviepy>=1.0.3
+av
+xformers
+gradio
+timm
+scikit-learn
+open_clip_torch==2.22.0
+kornia
+sk-video>=1.1.10
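requirements.txt pins several packages to versions that differ from cog.yaml (for example torch 2.0.0 vs 2.0.1, pytorch_lightning 1.8.3 vs 2.1.0, transformers 4.25.1 vs 4.35.0, einops 0.3.0 vs 0.7.0). A small sketch that surfaces such drift between the two files, assuming both sit in the working directory:

# Report packages pinned to different versions in requirements.txt
# and cog.yaml. Uses PyYAML, which requirements.txt already pins.
import re
import yaml

PIN = re.compile(r"([A-Za-z0-9_.\-]+)==(\S+)")

def normalize(name: str) -> str:
    # PyPI treats "-" and "_" (and case) as equivalent in package names.
    return name.lower().replace("-", "_")

def pins_from_requirements(path: str = "requirements.txt") -> dict:
    with open(path) as f:
        matches = (PIN.match(line.strip()) for line in f)
        return {normalize(m.group(1)): m.group(2) for m in matches if m}

def pins_from_cog(path: str = "cog.yaml") -> dict:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    matches = (PIN.match(spec) for spec in cfg["build"]["python_packages"])
    return {normalize(m.group(1)): m.group(2) for m in matches if m}

reqs, cog = pins_from_requirements(), pins_from_cog()
for name in sorted(reqs.keys() & cog.keys()):
    if reqs[name] != cog[name]:
        print(f"{name}: requirements.txt={reqs[name]} cog.yaml={cog[name]}")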