import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvtrans
from einops import rearrange
import pytorch_lightning as pl

from . import get_model
from ..cfg_helper import model_cfg_bank
from ..common.utils import regularize_image, regularize_video, remove_duplicate_word

import warnings
warnings.filterwarnings("ignore")


class dani_model(pl.LightningModule):
    """Wraps the composable diffusion network and a DDIM sampler for any-to-any inference."""

    def __init__(self, model='thesis_model', load_weights=True, data_dir='pretrained',
                 pth=["CoDi_encoders.pth"], fp16=False):
        super().__init__()

        cfgm = model_cfg_bank()(model)
        net = get_model()(cfgm)

        if load_weights:
            for path in pth:
                net.load_state_dict(torch.load(os.path.join(data_dir, path), map_location='cpu'),
                                    strict=False)
            print('Loaded pretrained weights from {}'.format(pth))

        self.net = net

        from core.models.ddim.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def decode(self, z, xtype):
        """Decode a latent `z` into the output space of modality `xtype`."""
        device = z.device
        net = self.net
        z = z.to(device)

        if xtype == 'image':
            x = net.autokl_decode(z)
            # Map from [-1, 1] to [0, 1].
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            return x

        elif xtype == 'video':
            num_frames = z.shape[2]
            # Decode frame-by-frame with the image autoencoder.
            z = rearrange(z, 'b c f h w -> (b f) c h w')
            x = net.autokl_decode(z)
            x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            video_list = []
            for video in x:
                video_list.append([tvtrans.ToPILImage()(xi) for xi in video])
            return video_list

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            """
            if prompt_merge_same_adj_word:
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi != 0 and wi == xi_split[idxi - 1]:
                            continue
                        xinew.append(wi)
                    xnew.append(remove_duplicate_word(' '.join(xinew)))
                x = xnew
            """
            return x

        elif xtype == 'audio':
            x = net.audioldm_decode(z)
            x = net.mel_spectrogram_to_waveform(x)
            return x

    def forward(self, xtype=[], condition=[], condition_types=[], n_samples=1,
                mix_weight={'video': 1, 'audio': 1, 'text': 1, 'image': 1},
                image_size=256, ddim_steps=50, scale=7.5, num_frames=8):
        """Sample `n_samples` outputs for each modality in `xtype`, conditioned on `condition`."""
        device = self.device
        net = self.net
        sampler = self.sampler
        ddim_eta = 0.0

        conditioning = []
        assert len(set(condition_types)) == len(condition_types), \
            "conditioning on the same modality more than once is not supported yet."
        assert len(condition) == len(condition_types)

        # Encode each condition; when classifier-free guidance is used (scale != 1.0),
        # the unconditional embedding is concatenated in front of the conditional one.
        for i, condition_type in enumerate(condition_types):
            if condition_type == 'image':
                ctemp1 = regularize_image(condition[i]).squeeze().to(device)
                ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
                cim = net.clip_encode_vision(ctemp1).to(device)
                uim = None
                if scale != 1.0:
                    dummy = torch.zeros_like(ctemp1).to(device)
                    uim = net.clip_encode_vision(dummy).to(device)
                conditioning.append(torch.cat([uim, cim]))

            elif condition_type == 'video':
                ctemp1 = regularize_video(condition[i]).to(device)
                ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1, 1)
                cim = net.clip_encode_vision(ctemp1).to(device)
                uim = None
                if scale != 1.0:
                    dummy = torch.zeros_like(ctemp1).to(device)
                    uim = net.clip_encode_vision(dummy).to(device)
                conditioning.append(torch.cat([uim, cim]))

            elif condition_type == 'audio':
                ctemp = condition[i][None].repeat(n_samples, 1, 1)
                cad = net.clap_encode_audio(ctemp)
                uad = None
                if scale != 1.0:
                    dummy = torch.zeros_like(ctemp)
                    uad = net.clap_encode_audio(dummy)
                conditioning.append(torch.cat([uad, cad]))

            elif condition_type == 'text':
                ctx = net.clip_encode_text(n_samples * [condition[i]]).to(device)
                utx = None
                if scale != 1.0:
                    utx = net.clip_encode_text(n_samples * [""]).to(device)
                conditioning.append(torch.cat([utx, ctx]))

        # Latent shapes for each requested output modality.
        shapes = []
        for xtype_i in xtype:
            if xtype_i == 'image':
                h, w = [image_size, image_size]
                shape = [n_samples, 4, h // 8, w // 8]
            elif xtype_i == 'video':
                h, w = [image_size, image_size]
                shape = [n_samples, 4, num_frames, h // 8, w // 8]
            elif xtype_i == 'text':
                n = 768
                shape = [n_samples, n]
            elif xtype_i == 'audio':
                h, w = [256, 16]
                shape = [n_samples, 8, h, w]
            else:
                raise ValueError('unknown output type: {}'.format(xtype_i))
            shapes.append(shape)

        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shapes,
            condition=conditioning,
            unconditional_guidance_scale=scale,
            xtype=xtype,
            condition_types=condition_types,
            eta=ddim_eta,
            verbose=False,
            mix_weight=mix_weight)

        out_all = []
        for i, xtype_i in enumerate(xtype):
            z[i] = z[i].to(device)
            x_i = self.decode(z[i], xtype_i)
            out_all.append(x_i)
        return out_all
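

# Minimal usage sketch (not part of the original module, and the checkpoint directory
# and file name below are assumptions that may differ per setup): instantiate the
# wrapper and run text-to-image inference through `forward`.
if __name__ == '__main__':
    inference = dani_model(
        model='thesis_model',
        load_weights=True,
        data_dir='pretrained',        # assumed checkpoint directory
        pth=['CoDi_encoders.pth'],    # assumed checkpoint file
    )
    # inference = inference.to('cuda')  # optionally move the model to a GPU

    outputs = inference(
        xtype=['image'],                         # generate images ...
        condition=['a photo of a red bicycle'],  # ... conditioned on this prompt
        condition_types=['text'],
        n_samples=1,
        image_size=256,
        ddim_steps=50,
        scale=7.5,
    )
    images = outputs[0]  # decoded image batch for the first output modality, clamped to [0, 1]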