import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvtrans

from einops import rearrange

import pytorch_lightning as pl

from . import get_model
from ..cfg_helper import model_cfg_bank
from ..common.utils import regularize_image, regularize_video, remove_duplicate_word

import warnings
warnings.filterwarnings("ignore")


class model_module(pl.LightningModule):
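    """Thin PyTorch Lightning wrapper around the CoDi network for inference.

    Holds the multimodal diffusion network and a DDIM sampler, exposing
    `forward` for conditional generation and `decode` for turning sampled
    latents back into images, videos, text, or audio.
    """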

    def __init__(self, model='codi', load_weights=True, data_dir='pretrained',
                 pth=["CoDi_encoders.pth"], fp16=False):
        super().__init__()

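        # Build the network from its config and optionally cast it to fp16.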
        cfgm = model_cfg_bank()(model)
        net = get_model()(cfgm)
        if fp16:
            net = net.half()
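        # Checkpoints are loaded non-strictly, so partial weight files
        # (e.g. encoder-only checkpoints) can be layered onto one network.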
        if load_weights:
            for path in pth:
                net.load_state_dict(
                    torch.load(os.path.join(data_dir, path), map_location='cpu'),
                    strict=False)
            print('Loaded pretrained weights from {}'.format(pth))

        self.net = net

        from core.models.ddim.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def decode(self, z, xtype):
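        """Decode sampled latents `z` into the modality named by `xtype`.

        `xtype` is one of 'image', 'video', 'text', or 'audio' and selects the
        matching decoder on the underlying network. Image and video outputs are
        rescaled from [-1, 1] to [0, 1]; video frames are returned as PIL images.
        """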
        device = z.device
        net = self.net
        z = z.to(device)
        if xtype == 'image':
            x = net.autokl_decode(z)
            x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
            return x

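        # Video latents are decoded frame by frame with the image autoencoder:
        # fold the frame axis into the batch, decode, then restore it and
        # convert each frame to a PIL image.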
        elif xtype == 'video':
            num_frames = z.shape[2]
            z = rearrange(z, 'b c f h w -> (b f) c h w')
            x = net.autokl_decode(z)
            x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)

            x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
            video_list = []
            for video in x:
                video_list.append([tvtrans.ToPILImage()(xi) for xi in video])
            return video_list

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            """
            if prompt_merge_same_adj_word:
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi!=0 and wi==xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(remove_duplicate_word(' '.join(xinew)))
                x = xnew
            """
            return x

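        # Audio latents decode to a mel spectrogram, which is then converted
        # into a waveform.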
        elif xtype == 'audio':
            x = net.audioldm_decode(z)
            x = net.mel_spectrogram_to_waveform(x)
            return x

    def forward(self, xtype=[], condition=[], condition_types=[], n_samples=1,
                mix_weight={'video': 1, 'audio': 1, 'text': 1, 'image': 1},
                image_size=256, ddim_steps=50, scale=7.5, num_frames=8):
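        """Generate `n_samples` outputs for each modality listed in `xtype`.

        `condition` and `condition_types` are parallel lists giving the
        conditioning inputs and their modalities ('image', 'video', 'audio',
        or 'text'). `scale` is the classifier-free guidance scale and
        `mix_weight` weights the conditioning modalities during sampling.
        Returns a list with one batch of decoded outputs per entry in `xtype`.
        """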
        device = self.device
        net = self.net
        sampler = self.sampler
        ddim_eta = 0.0

        conditioning = []
        assert len(set(condition_types)) == len(condition_types), \
            'conditioning on multiple inputs of the same modality is not supported yet.'
        assert len(condition) == len(condition_types)

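        # Encode each condition with the matching encoder. For classifier-free
        # guidance (scale != 1.0), the unconditional embedding is concatenated
        # in front of the conditional one along the batch dimension.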
for i, condition_type in enumerate(condition_types): |
|
if condition_type == 'image': |
|
print(condition[i].shape) |
|
ctemp1 = regularize_image(condition[i]).squeeze().to(device) |
|
print(ctemp1.shape) |
|
ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1) |
|
cim = net.clip_encode_vision(ctemp1).to(device) |
|
uim = None |
|
if scale != 1.0: |
|
dummy = torch.zeros_like(ctemp1).to(device) |
|
uim = net.clip_encode_vision(dummy).to(device) |
|
conditioning.append(torch.cat([uim, cim])) |
|
|
|
            elif condition_type == 'video':
                ctemp1 = regularize_video(condition[i]).to(device)
                ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1, 1)
                cim = net.clip_encode_vision(ctemp1).to(device)
                uim = None
                if scale != 1.0:
                    dummy = torch.zeros_like(ctemp1).to(device)
                    uim = net.clip_encode_vision(dummy).to(device)
                conditioning.append(torch.cat([uim, cim]))

            elif condition_type == 'audio':
                ctemp = condition[i][None].repeat(n_samples, 1, 1)
                cad = net.clap_encode_audio(ctemp)
                uad = None
                if scale != 1.0:
                    dummy = torch.zeros_like(ctemp)
                    uad = net.clap_encode_audio(dummy)
                conditioning.append(torch.cat([uad, cad]))

            elif condition_type == 'text':
                ctx = net.clip_encode_text(n_samples * [condition[i]]).to(device)
                utx = None
                if scale != 1.0:
                    utx = net.clip_encode_text(n_samples * [""]).to(device)
                conditioning.append(torch.cat([utx, ctx]))

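        # Latent shape per output modality: images and videos use a 4-channel
        # VAE latent at 1/8 spatial resolution, text uses a 768-dimensional
        # embedding, and audio uses an 8-channel spectrogram-shaped latent.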
        shapes = []
        for xtype_i in xtype:
            if xtype_i == 'image':
                h, w = [image_size, image_size]
                shape = [n_samples, 4, h//8, w//8]
            elif xtype_i == 'video':
                h, w = [image_size, image_size]
                shape = [n_samples, 4, num_frames, h//8, w//8]
            elif xtype_i == 'text':
                n = 768
                shape = [n_samples, n]
            elif xtype_i == 'audio':
                h, w = [256, 16]
                shape = [n_samples, 8, h, w]
            else:
                raise ValueError('unknown output type: {}'.format(xtype_i))
            shapes.append(shape)

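        # Jointly sample latents for all requested output modalities with DDIM.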
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shapes,
            condition=conditioning,
            unconditional_guidance_scale=scale,
            xtype=xtype,
            condition_types=condition_types,
            eta=ddim_eta,
            verbose=False,
            mix_weight=mix_weight)

        out_all = []
        for i, xtype_i in enumerate(xtype):
            z[i] = z[i].to(device)
            x_i = self.decode(z[i], xtype_i)
            out_all.append(x_i)
        return out_all
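
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes pretrained CoDi checkpoints are available under `pretrained/`; the
# checkpoint file names below are examples only.
#
#   inference_tester = model_module(
#       model='codi',
#       data_dir='pretrained',
#       pth=['CoDi_encoders.pth', 'CoDi_text_diffuser.pth'])
#   inference_tester = inference_tester.cuda().eval()
#   images = inference_tester(
#       xtype=['image'],
#       condition=['a painting of a panda playing guitar'],
#       condition_types=['text'],
#       n_samples=1,
#       image_size=256,
#       ddim_steps=50)[0]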