import inspect
import os

import cv2
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
from omegaconf import DictConfig
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel

from fid_utils import calculate_fid_given_features
from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
from models.blip_override.blip import blip_feature_extractor, init_tokenizer
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
from models.inception import InceptionV3
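
# UNet submodule class names targeted when patching in LoRA layers with lora_diffusion
# (presumably consumed by monkeypatch_or_replace_lora elsewhere in the project).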
unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}


class LightningDataset(pl.LightningDataModule):
    def __init__(self, args: DictConfig):
        super().__init__()
        self.kwargs = {"num_workers": args.num_workers,
                       "persistent_workers": True if args.num_workers > 0 else False,
                       "pin_memory": True}
        self.args = args

    def setup(self, stage="fit"):
        if self.args.dataset == "pororo":
            import datasets.pororo as data
        elif self.args.dataset == "flintstones":
            import datasets.flintstones as data
        elif self.args.dataset == "vistsis":
            import datasets.vistsis as data
        elif self.args.dataset == "vistdii":
            import datasets.vistdii as data
        else:
            raise ValueError("Unknown dataset: {}".format(self.args.dataset))
        if stage == "fit":
            self.train_data = data.StoryDataset("train", self.args)
            self.val_data = data.StoryDataset("val", self.args)
        if stage == "test":
            self.test_data = data.StoryDataset("test", self.args)

    def train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True,
                                          **self.kwargs)
        return self.trainloader

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def predict_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

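    # The training loader is cached so steps_per_epoch can be computed before trainer.fit() builds its own loaders.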
    def get_length_of_train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True,
                                          **self.kwargs)
        return len(self.trainloader)


class ARLDM(pl.LightningModule):
    def __init__(self, args: DictConfig, steps_per_epoch=1):
        super().__init__()
        self.args = args
        self.steps_per_epoch = steps_per_epoch

        # Configurations
        self.task = args.task

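        # Inference-only components: the sampling scheduler and InceptionV3 (for FID) are built only in 'sample' mode.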
        if args.mode == 'sample':
            if args.scheduler == "pndm":
                self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               skip_prk_steps=True)
            elif args.scheduler == "ddim":
                self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               clip_sample=False, set_alpha_to_one=True)
            else:
                raise ValueError("Scheduler not supported")
            self.fid_augment = transforms.Compose([
                transforms.Resize([64, 64]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception = InceptionV3([block_idx])

        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()
        self.blip_image_processor = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])
        self.max_length = args.get(args.dataset).max_length

        # Null (unconditional) tokens, used as the dropped-out conditioning for classifier-free guidance.
        blip_image_null_token = self.blip_image_processor(
            Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
        clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids
        blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids

        self.register_buffer('clip_text_null_token', clip_text_null_token)
        self.register_buffer('blip_text_null_token', blip_text_null_token)
        self.register_buffer('blip_image_null_token', blip_image_null_token)

        self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
                                                           subfolder="text_encoder")
        self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)

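        # Stretch CLIP's learned position embeddings from their pretrained length to max_length so longer prompts fit.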
        old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
        new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
        self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
        self.text_encoder.config.max_position_embeddings = self.max_length
        self.text_encoder.max_position_embeddings = self.max_length
        self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))

        self.modal_type_embeddings = nn.Embedding(2, 768)
        self.time_embeddings = nn.Embedding(5, 768)
        self.mm_encoder = blip_feature_extractor(
            pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
            image_size=224, vit='large')
        self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)

        self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")

        self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                             num_train_timesteps=1000)

        self.freeze_params(self.vae.parameters())
        if args.freeze_resnet:
            self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])

        if args.freeze_blip and hasattr(self, "mm_encoder"):
            self.freeze_params(self.mm_encoder.parameters())
            self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())

        if args.freeze_clip and hasattr(self, "text_encoder"):
            self.freeze_params(self.text_encoder.parameters())
            self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())

    @staticmethod
    def freeze_params(params):
        for param in params:
            param.requires_grad = False

    @staticmethod
    def unfreeze_params(params):
        for param in params:
            param.requires_grad = True

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4)
        # Warmup/annealing horizons are configured in epochs, so convert them to optimizer steps.
        scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                                  warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
                                                  max_epochs=self.args.max_epochs * self.steps_per_epoch)
        optim_dict = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',  # step the scheduler every optimizer step, not every epoch
            }
        }
        return optim_dict

    def forward(self, batch):
        if self.args.freeze_clip and hasattr(self, "text_encoder"):
            self.text_encoder.eval()
        if self.args.freeze_blip and hasattr(self, "mm_encoder"):
            self.mm_encoder.eval()
        images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
        B, V, S = captions.shape
        src_V = V + 1 if self.task == 'continuation' else V
        images = torch.flatten(images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)

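        # Classifier-free guidance training: drop the conditioning for roughly 10% of the flattened samples.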
        classifier_free_idx = np.random.rand(B * V) < 0.1

        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        caption_embeddings[classifier_free_idx] = \
            self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
        source_embeddings[classifier_free_idx] = \
            self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
                            mode='multimodal')[0].repeat(src_V, 1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)

        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        attention_mask = ~(attention_mask.bool())
        attention_mask[classifier_free_idx] = False

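        # Causal mask over frames: True means masked, so frame i only attends to embeddings of strictly earlier frames.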
        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])

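        # Encode images into latent space; 0.18215 is the Stable Diffusion VAE scaling factor.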
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215

        noise = torch.randn(latents.shape, device=self.device)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        return loss

    def sample(self, batch):
        original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
        B, V, S = captions.shape
        src_V = V + 1 if self.task == 'continuation' else V
        original_images = torch.flatten(original_images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)

        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)

        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        attention_mask = ~(attention_mask.bool())

        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])

        uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
        uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
                                                   attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
        uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        uncond_source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
        uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)

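        # Stack the unconditional and conditional streams for classifier-free guidance (batch doubles to 2 * B * V).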
        encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
        uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
        uncond_attention_mask[:, -V * S:] = square_mask
        attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)

        attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
        images = list()
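        # Autoregressive loop: generate frame i, then feed its BLIP embedding back in as conditioning for later frames.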
        for i in range(V):
            encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
            new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
                                       attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
                                       512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
            images += new_image

            new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)
            new_embedding = self.mm_encoder(new_image,
                                            source_caption.reshape(B, src_V, S)[:, i + src_V - V],
                                            source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
                                            mode='multimodal')
            new_embedding = new_embedding.repeat_interleave(V, dim=0)
            new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
            new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))

            encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
            encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
            encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])

        return original_images, images, texts, ori_test_images

    def training_step(self, batch, batch_idx):
        loss = self(batch)
        self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch)
        self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        original_images, images, texts, ori_test_images = self.sample(batch)
        if self.args.calculate_fid:
            original_images = original_images.cpu().numpy().astype('uint8')
            original_images = [Image.fromarray(im, 'RGB') for im in original_images]
            ori = self.inception_feature(original_images).cpu().numpy()
            gen = self.inception_feature(images).cpu().numpy()
        else:
            ori = None
            gen = None
        return images, ori, gen, ori_test_images, texts

    def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale,
                  eta):
        latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
                              device=self.device)

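        # Different diffusers schedulers accept different kwargs, so probe the signatures before passing offset/eta.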
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.scheduler.timesteps):
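            # Duplicate latents for the unconditional/conditional halves of classifier-free guidance.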
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Decode latents back to pixel space (undoing the VAE scaling factor first).
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return self.numpy_to_pil(image)

    @staticmethod
    def numpy_to_pil(images):
        """Convert a numpy image or a batch of images to a list of PIL images."""
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image, 'RGB') for image in images]
        return pil_images

    def inception_feature(self, images):
        images = torch.stack([self.fid_augment(image) for image in images])
        images = images.type(torch.FloatTensor).to(self.device)
        images = (images + 1) / 2
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        pred = self.inception(images)[0]

        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred.reshape(-1, 2048)


def train(args: DictConfig) -> None:
    dataloader = LightningDataset(args)
    dataloader.setup('fit')

    model = ARLDM(args, steps_per_epoch=dataloader.get_length_of_train_dataloader())

    logger = TensorBoardLogger(save_dir=os.path.join(args.ckpt_dir, args.run_name), name='log',
                               default_hp_metric=False)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.ckpt_dir, args.run_name),
        save_top_k=0,
        save_last=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    callback_list = [lr_monitor, checkpoint_callback]

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=args.max_epochs,
        benchmark=True,
        logger=logger,
        log_every_n_steps=1,
        callbacks=callback_list,
        strategy=DDPStrategy(find_unused_parameters=False)
    )
    trainer.fit(model, dataloader, ckpt_path=args.train_model_file)


def sample(args: DictConfig) -> None:
    assert args.test_model_file is not None, "test_model_file cannot be None"
    assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"
    dataloader = LightningDataset(args)
    dataloader.setup('test')
    model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)

    predictor = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=-1,
        benchmark=True
    )
    predictions = predictor.predict(model, dataloader)
    images = [elem for sublist in predictions for elem in sublist[0]]
    ori_images = [elem for sublist in predictions for elem in sublist[3]]
    ori_test_images = list()
    os.makedirs(args.sample_output_dir, exist_ok=True)

    text_list = [elem for sublist in predictions for elem in sublist[4]]

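    # Each story consists of 5 frames, so generated images are printed and saved in groups of 5.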
    num_images = len(images)
    num_groups = (num_images + 4) // 5

    for g in range(num_groups):
        print('Story {}:'.format(g + 1))
        start_index = g * 5
        end_index = min(start_index + 5, num_images)
        for i in range(start_index, end_index):
            print(text_list[i])
            images[i].save(
                os.path.join(args.sample_output_dir,
                             'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))

            ori_images_pil = Image.fromarray(
                np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
            ori_test_images.append(ori_images_pil)
            ori_images_pil.save(
                os.path.join('/root/lihui/StoryVisualization/ori_test_images_epoch10',
                             'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))

    if args.calculate_fid:
        ori = np.array([elem for sublist in predictions for elem in sublist[1]])
        gen = np.array([elem for sublist in predictions for elem in sublist[2]])
        fid = calculate_fid_given_features(ori, gen)
        print('FID: {}'.format(fid))


@hydra.main(config_path=".", config_name="config")
def main(args: DictConfig) -> None:
    pl.seed_everything(args.seed)
    if args.num_cpu_cores > 0:
        torch.set_num_threads(args.num_cpu_cores)

    if args.mode == 'train':
        train(args)
    elif args.mode == 'sample':
        sample(args)


if __name__ == '__main__':
    main()