# Provenance (copied from the hosting page, not code): uploaded by Anyou,
# commit message "Upload 11 files", revision 4a6e43e.
import inspect
import os
import cv2
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
from omegaconf import DictConfig
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
from fid_utils import calculate_fid_given_features
from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
from models.blip_override.blip import blip_feature_extractor, init_tokenizer
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
from models.inception import InceptionV3
# UNet module class names that LoRA patching would target. In this file it is only
# referenced by the commented-out monkeypatch_or_replace_lora call in ARLDM.__init__.
unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}
#!/usr/bin/env python3
from transformers import CLIPProcessor
import transformers
from PIL import Image
import PIL.Image
import numpy as np
import torchvision.transforms as tvtrans
import requests
from io import BytesIO
class LightningDataset(pl.LightningDataModule):
    """Lightning data module wrapping the story datasets selected by ``args.dataset``.

    ``args`` must provide ``dataset``, ``num_workers`` and ``batch_size``; it is
    also passed unchanged to every ``StoryDataset`` constructor.
    """

    # Supported values of ``args.dataset``; each is importable as ``datasets.<name>``.
    _DATASET_NAMES = ("pororo", "flintstones", "vistsis", "vistdii")

    def __init__(self, args: DictConfig):
        super(LightningDataset, self).__init__()
        # Shared DataLoader options; workers are only kept alive between epochs
        # when there are worker processes at all.
        self.kwargs = {"num_workers": args.num_workers, "persistent_workers": args.num_workers > 0,
                       "pin_memory": True}
        self.args = args

    def _dataset_module(self):
        """Import and return the ``datasets.<name>`` module for ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` is not one of ``_DATASET_NAMES``.
        """
        import importlib

        name = self.args.dataset
        if name not in self._DATASET_NAMES:
            raise ValueError("Unknown dataset: {}".format(name))
        return importlib.import_module("datasets.{}".format(name))

    def _build_split(self, attr, split, data):
        """Set ``self.<attr>`` to ``data.StoryDataset(split, args)`` unless it is already set."""
        if not hasattr(self, attr):
            setattr(self, attr, data.StoryDataset(split, self.args))

    def setup(self, stage="fit"):
        """Build the dataset splits needed for ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val, ``"test"`` and
        ``"predict"`` build test, and ``None`` builds everything. A split that
        already exists is reused. This way the manual ``setup`` call made before
        handing the module to a Trainer is not repeated when the Trainer calls
        ``setup`` itself.

        Raises:
            ValueError: if ``args.dataset`` is not supported.
        """
        data = self._dataset_module()
        if stage in ("fit", None):
            self._build_split("train_data", "train", data)
        if stage in ("fit", "validate", None):
            self._build_split("val_data", "val", data)
        # Lightning calls setup("predict") for Trainer.predict, which reads test_data.
        if stage in ("test", "predict", None):
            self._build_split("test_data", "test", data)

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the shared batch size and options."""
        return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=shuffle, **self.kwargs)

    def train_dataloader(self):
        """Return the shuffled training loader, creating it on first use."""
        # Cached so get_length_of_train_dataloader and the Trainer share one loader.
        if not hasattr(self, 'trainloader'):
            self.trainloader = self._build_loader(self.train_data, shuffle=True)
        return self.trainloader

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return self._build_loader(self.val_data, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return self._build_loader(self.test_data, shuffle=False)

    def predict_dataloader(self):
        """Return an unshuffled loader over the test split (prediction uses test data)."""
        return self._build_loader(self.test_data, shuffle=False)

    def get_length_of_train_dataloader(self):
        """Return the number of training batches per epoch."""
        return len(self.train_dataloader())
class ARLDM(pl.LightningModule):
    """Autoregressive latent diffusion model (LightningModule) for story generation.

    Each target frame is denoised by the Stable Diffusion v1.5 UNet. Its
    cross-attention context is the concatenation of two parts:
      * CLIP text-encoder states of the frame's own caption (modal type 0), and
      * BLIP multimodal states of the source frames and captions (modal type 1,
        plus a learned per-slot time embedding).
    A causal mask hides the source slots of the current and later frames.
    ``sample`` feeds every generated frame back into the context before it
    generates the next one.

    ``args`` is the run configuration. Keys read here include ``task``,
    ``mode``, ``scheduler``, ``dataset`` (and the per-dataset sub-config),
    the ``freeze_*`` flags, the optimizer settings and the sampling settings.
    """

    def __init__(self, args: DictConfig, steps_per_epoch=1):
        """Build the encoders, the VAE/UNet pair and the null-condition tokens.

        Args:
            args: run configuration (see the class docstring).
            steps_per_epoch: optimizer steps per epoch. ``configure_optimizers``
                uses it to express the epoch-based warmup and cosine schedule in steps.

        Raises:
            ValueError: in 'sample' mode, if ``args.scheduler`` is not "pndm" or "ddim".
        """
        super(ARLDM, self).__init__()
        self.args = args
        self.steps_per_epoch = steps_per_epoch
        """
        Configurations
        """
        self.task = args.task
        # Inference-only components: the sampling scheduler and the InceptionV3
        # feature extractor used for FID are only built in 'sample' mode.
        if args.mode == 'sample':
            if args.scheduler == "pndm":
                self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               skip_prk_steps=True)
            elif args.scheduler == "ddim":
                self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               clip_sample=False, set_alpha_to_one=True)
            else:
                raise ValueError("Scheduler not supported")
            # Images for InceptionV3 are first shrunk to 64x64 and normalized to [-1, 1];
            # inception_feature does the rest of the preprocessing.
            self.fid_augment = transforms.Compose([
                transforms.Resize([64, 64]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            # 2048-d InceptionV3 features; predict_step feeds these to the FID computation.
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception = InceptionV3([block_idx])
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        ##############################
        #self.clip_tokenizer.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/tokenizer')
        self.blip_tokenizer = init_tokenizer()
        # BLIP image preprocessing: resize to 224x224, then per-channel normalization.
        self.blip_image_processor = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])
        self.max_length = args.get(args.dataset).max_length
        # "Null" inputs for classifier-free guidance: an all-black image and empty
        # captions padded to max_length. They are registered as buffers so they
        # move with the module's device.
        blip_image_null_token = self.blip_image_processor(
            Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
        clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids
        blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids
        self.register_buffer('clip_text_null_token', clip_text_null_token)
        self.register_buffer('blip_text_null_token', blip_text_null_token)
        self.register_buffer('blip_image_null_token', blip_image_null_token)
        self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
                                                          subfolder="text_encoder")
        ############################################
        #self.text_encoder.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/text_encoder')
        # Resize the vocabulary to the dataset-specific token count (presumably to
        # hold extra tokens added by the dataset's tokenizer; TODO confirm).
        self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)
        # resize_position_embeddings
        # Resize CLIP's position table to max_length so captions of that length can be
        # encoded. _get_resized_embeddings is a private transformers helper and may
        # change between library versions.
        old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
        new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
        self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
        self.text_encoder.config.max_position_embeddings = self.max_length
        self.text_encoder.max_position_embeddings = self.max_length
        self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))
        # Modality tag (0 = CLIP caption states, 1 = BLIP source states) and a per-slot
        # time tag for source frames. Both are added to encoder outputs, so 768 must
        # equal the encoders' hidden size. The time table has 5 rows, so src_V must be <= 5.
        self.modal_type_embeddings = nn.Embedding(2, 768)
        self.time_embeddings = nn.Embedding(5, 768)
        # NOTE(review): the BLIP checkpoint path below is machine-specific.
        self.mm_encoder = blip_feature_extractor(
            # pretrained='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
            pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
            image_size=224, vit='large')#, local_files_only=True)
        self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)
        self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
        # Training-time noise schedule (1000 DDPM steps). Sampling uses self.scheduler instead.
        self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                             num_train_timesteps=1000)
        # Disabled LoRA experiment, kept for reference:
        # monkeypatch_or_replace_lora(
        # self.unet,
        # torch.load("lora/example_loras/analog_svd_rank4.safetensors"),
        # r=4,
        # target_replace_module=unet_target_replace_module,
        # )
        #
        # tune_lora_scale(self.unet, 1.00)
        #tune_lora_scale(self.text_encoder, 1.00)
        # torch.manual_seed(0)
        ###################################
        #self.vae.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/vae')
        #self.unet.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/unet')
        # Always freeze the VAE. With freeze_resnet, also freeze every UNet parameter
        # outside the "attentions" modules.
        self.freeze_params(self.vae.parameters())
        if args.freeze_resnet:
            self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])
        # Frozen encoders still train their word/token embeddings (e.g. the resized vocabulary).
        # hasattr() is always true here because both encoders were just created.
        if args.freeze_blip and hasattr(self, "mm_encoder"):
            self.freeze_params(self.mm_encoder.parameters())
            self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())
        if args.freeze_clip and hasattr(self, "text_encoder"):
            self.freeze_params(self.text_encoder.parameters())
            self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())

    @staticmethod
    def freeze_params(params):
        """Set ``requires_grad = False`` on every parameter in ``params``."""
        for param in params:
            param.requires_grad = False

    @staticmethod
    def unfreeze_params(params):
        """Set ``requires_grad = True`` on every parameter in ``params``."""
        for param in params:
            param.requires_grad = True

    def configure_optimizers(self):
        """AdamW (lr=args.init_lr, weight decay 1e-4) with linear warmup + cosine annealing.

        The scheduler is stepped every optimizer step. Its epoch-based bounds are
        therefore multiplied by ``steps_per_epoch``.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4) # optim_bits=8
        scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                                  warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
                                                  max_epochs=self.args.max_epochs * self.steps_per_epoch)
        optim_dict = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler, # The LR scheduler instance (required)
                'interval': 'step', # The unit of the scheduler's step size
            }
        }
        return optim_dict

    def forward(self, batch):
        """Return the denoising loss (noise-prediction MSE) for one training batch.

        ``batch`` unpacks to images, captions, attention_mask, source_images,
        source_caption, source_attention_mask, texts and ori_images. ``captions``
        has shape (B, V, S). For about 10% of frames the conditioning is replaced
        by the null embeddings (classifier-free guidance training).
        """
        # Frozen encoders are kept in eval mode even while the module trains.
        if self.args.freeze_clip and hasattr(self, "text_encoder"):
            self.text_encoder.eval()
        if self.args.freeze_blip and hasattr(self, "mm_encoder"):
            self.mm_encoder.eval()
        images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
        B, V, S = captions.shape
        # In 'continuation' the source sequence has one extra leading frame.
        src_V = V + 1 if self.task == 'continuation' else V
        # Merge the first two dims: (B, V, ...) -> (B * V, ...), batch-major.
        images = torch.flatten(images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
        # Incoming masks: 1 is not masked, 0 is masked
        # Per-frame flags choosing which frames get the null condition (~10%).
        classifier_free_idx = np.random.rand(B * V) < 0.1
        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        # Every frame of a story shares the same source context.
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        # Overwrite the dropped frames with the null-caption / null-image+caption states.
        caption_embeddings[classifier_free_idx] = \
            self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
        source_embeddings[classifier_free_idx] = \
            self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
                            mode='multimodal')[0].repeat(src_V, 1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        # Slot index 0..src_V-1, repeated for each of the S tokens in a slot.
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        # Inverted: from here on True marks a position to ignore (assumes the
        # overridden UNet uses that convention).
        attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
        # Null-conditioned frames attend to every (null) token, padding included.
        attention_mask[classifier_free_idx] = False
        # B, V, V, S
        # Causal mask over the last V source slots. Row v, local slot k is masked
        # when k >= v (upper triangle incl. diagonal), so frame v only sees
        # context from earlier frames.
        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
        latents = self.vae.encode(images).latent_dist.sample()
        # Presumably the SD v1 VAE latent scaling factor (undone in diffusion()).
        latents = latents * 0.18215
        noise = torch.randn(latents.shape, device=self.device)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
        # Per-sample MSE over (C, H, W), then the mean over the batch.
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        return loss

    def sample(self, batch):
        """Autoregressively generate all V frames of every story in ``batch``.

        Frames are produced one at a time with classifier-free guidance. After
        frame i is generated, its BLIP embedding overwrites source slot
        ``i + src_V - V`` in the conditional context used for the later frames.
        The unconditional context is never updated.

        Returns:
            (original_images, images, texts, ori_test_images).
            ``original_images`` is the input tensor flattened over (B, V).
            ``images`` is a list of B * V PIL images.

        NOTE(review): ``images`` is filled frame by frame (index ``i * B + b``),
        while ``original_images`` is batch-major (index ``b * V + i``). The two
        orders only agree when B == 1.
        """
        original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
        B, V, S = captions.shape
        src_V = V + 1 if self.task == 'continuation' else V
        original_images = torch.flatten(original_images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
        # Conditional context, built exactly as in forward() minus the null dropout.
        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
        # B, V, V, S
        # Same causal mask over the last V source slots as in forward().
        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
        # Unconditional context: null caption + null source for every slot, shared by all frames.
        uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
        uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
                                                   attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
        uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        uncond_source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
        uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)
        # Stack as [uncond; cond] along the batch axis, the order diffusion() expects.
        encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
        # The unconditional half sees every token except the causally masked source slots.
        uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
        uncond_attention_mask[:, -V * S:] = square_mask
        attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)
        attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
        images = list()
        for i in range(V):
            # View as (uncond/cond, B, V, L, D) and select frame i of every story.
            encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
            new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
                                       attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
                                       512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
            images += new_image
            # Re-encode the B generated frames with BLIP, together with their source captions.
            new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)
            new_embedding = self.mm_encoder(new_image, # B,C,H,W
                                            source_caption.reshape(B, src_V, S)[:, i + src_V - V],
                                            source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
                                            mode='multimodal') # B, S, D
            new_embedding = new_embedding.repeat_interleave(V, dim=0)
            new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
            new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))
            # Write it into the conditional half only (index 1). Source slot j starts
            # at (j + 1) * S because the S caption tokens come first.
            encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
            encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
            encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
        return original_images, images, texts, ori_test_images

    def training_step(self, batch, batch_idx):
        """Compute and log the per-step training loss."""
        loss = self(batch)
        self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss and log it as an epoch-level validation metric."""
        loss = self(batch)
        self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate a batch of stories and, if ``args.calculate_fid``, their Inception features.

        Returns:
            (images, ori, gen, ori_test_images, texts). ``ori`` and ``gen`` are
            (N, 2048) numpy feature arrays for the reference and generated images,
            or None when FID is disabled.
        """
        original_images, images, texts, ori_test_images = self.sample(batch)
        if self.args.calculate_fid:
            # Assumes original_images holds HWC pixel values in 0..255; TODO confirm against the dataset.
            original_images = original_images.cpu().numpy().astype('uint8')
            original_images = [Image.fromarray(im, 'RGB') for im in original_images]
            # ori_test_images = torch.stack(ori_test_images).cpu().numpy().astype('uint8')
            # ori_test_images = [Image.fromarray(im, 'RGB') for im in ori_test_images]
            ori = self.inception_feature(original_images).cpu().numpy()
            gen = self.inception_feature(images).cpu().numpy()
        else:
            ori = None
            gen = None
        return images, ori, gen, ori_test_images, texts

    def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale, eta):
        """Run classifier-free-guided denoising and decode the result to PIL images.

        Args:
            encoder_hidden_states: (2 * N, L, D) context. The first N rows are the
                unconditional context and the last N the conditional one.
            attention_mask: (2 * N, L) mask that matches ``encoder_hidden_states``.
            height, width: output size in pixels; latents are 1/8 of that.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight.
            eta: passed to ``scheduler.step`` only if that method accepts ``eta``.

        Returns:
            List of N PIL images.
        """
        latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
                              device=self.device)
        # set timesteps
        # Pass offset=1 only to schedulers whose set_timesteps accepts it.
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
        # (never taken with the PNDM/DDIM schedulers built in __init__)
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states).sample
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample
            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        # Map the decoder output from [-1, 1] to [0, 1], then to NHWC numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return self.numpy_to_pil(image)

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a float image array in [0, 1] (HWC, or NHWC for a batch) to a list of RGB PIL images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image, 'RGB') for image in images]
        return pil_images

    def inception_feature(self, images):
        """Return (N, 2048) InceptionV3 features for a list of N PIL images."""
        # fid_augment yields 64x64 tensors in [-1, 1]. Map them back to [0, 1] and
        # upsample to 299x299 (presumably the input range/size InceptionV3 expects).
        images = torch.stack([self.fid_augment(image) for image in images])
        images = images.type(torch.FloatTensor).to(self.device)
        images = (images + 1) / 2
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        pred = self.inception(images)[0]
        # Global-average-pool if the selected block is not already 1x1 spatially.
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred.reshape(-1, 2048)
def train(args: DictConfig) -> None:
    """Fit ARLDM on the configured dataset.

    Checkpoints and TensorBoard logs go under ``<args.ckpt_dir>/<args.run_name>``.
    The checkpoint callback keeps no top-k checkpoints, only the last one.
    ``args.train_model_file`` is passed to ``Trainer.fit`` as ``ckpt_path``.
    """
    run_dir = os.path.join(args.ckpt_dir, args.run_name)

    datamodule = LightningDataset(args)
    datamodule.setup('fit')
    # The LR schedule steps every optimizer step, so the model needs to know the
    # number of batches per epoch when it is constructed.
    batches_per_epoch = datamodule.get_length_of_train_dataloader()
    model = ARLDM(args, steps_per_epoch=batches_per_epoch)

    tb_logger = TensorBoardLogger(save_dir=run_dir, name='log', default_hp_metric=False)
    last_ckpt = ModelCheckpoint(dirpath=run_dir, save_top_k=0, save_last=True)

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=args.max_epochs,
        benchmark=True,
        logger=tb_logger,
        log_every_n_steps=1,
        callbacks=[LearningRateMonitor(logging_interval='step'), last_ckpt],
        strategy=DDPStrategy(find_unused_parameters=False),
    )
    trainer.fit(model, datamodule, ckpt_path=args.train_model_file)
def sample(args: DictConfig) -> None:
    """Generate stories from a trained checkpoint, save them, and optionally report FID.

    Generated frames are written to ``args.sample_output_dir`` and the matching
    ground-truth frames to a fixed directory, both as ``groupGG_imageII.png``.
    Every run of ``story_len`` consecutive outputs is treated as one story.
    Each caption is printed as its frame is saved.

    Raises:
        AssertionError: if ``args.test_model_file`` is unset or more than one GPU
            is requested.
    """
    # Frames per story assumed when grouping the flat list of outputs.
    story_len = 5
    # Fixed destination for the ground-truth frames (machine-specific path kept as-is).
    ori_output_dir = '/root/lihui/StoryVisualization/ori_test_images_epoch10'

    assert args.test_model_file is not None, "test_model_file cannot be None"
    gpu_ids = args.gpu_ids
    # gpu_ids may be a GPU count (int) or a list of device ids. Calling len() on an
    # int would raise TypeError instead of the intended assertion message.
    single_gpu = gpu_ids == 1 or (not isinstance(gpu_ids, int) and len(gpu_ids) == 1)
    assert single_gpu, "Only one GPU is supported in test mode"

    dataloader = LightningDataset(args)
    dataloader.setup('test')
    model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)
    predictor = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=-1,
        benchmark=True
    )
    predictions = predictor.predict(model, dataloader)
    # Each prediction is (images, ori_features, gen_features, ori_test_images, texts),
    # one tuple per batch; flatten the per-batch lists.
    images = [elem for sublist in predictions for elem in sublist[0]]
    ori_images = [elem for sublist in predictions for elem in sublist[3]]
    text_list = [elem for sublist in predictions for elem in sublist[4]]

    # The old code wrapped mkdir in a bare except, which hid real failures. It never
    # created the ground-truth directory at all, so the first save into it raised
    # FileNotFoundError.
    os.makedirs(args.sample_output_dir, exist_ok=True)
    os.makedirs(ori_output_dir, exist_ok=True)

    num_images = len(images)
    for group_idx, start in enumerate(range(0, num_images, story_len)):
        print('Story {}:'.format(group_idx + 1))
        end = min(start + story_len, num_images)
        for i in range(start, end):
            print(text_list[i])
            file_name = 'group{:02d}_image{:02d}.png'.format(group_idx + 1, i - start + 1)
            images[i].save(os.path.join(args.sample_output_dir, file_name))
            ori_pil = Image.fromarray(np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
            ori_pil.save(os.path.join(ori_output_dir, file_name))

    if args.calculate_fid:
        ori = np.array([elem for sublist in predictions for elem in sublist[1]])
        gen = np.array([elem for sublist in predictions for elem in sublist[2]])
        fid = calculate_fid_given_features(ori, gen)
        print('FID: {}'.format(fid))
@hydra.main(config_path=".", config_name="config")
def main(args: DictConfig) -> None:
    """Hydra entry point: seed RNGs, optionally cap CPU threads, then run ``args.mode``.

    ``args.mode`` selects ``train`` or ``sample``. Any other value does nothing.
    """
    pl.seed_everything(args.seed)
    if args.num_cpu_cores > 0:
        torch.set_num_threads(args.num_cpu_cores)

    runners = {'train': train, 'sample': sample}
    runner = runners.get(args.mode)
    if runner is not None:
        runner(args)


if __name__ == '__main__':
    # The hydra.main decorator builds ``args`` from the "config" config in this directory.
    main()