import sys
import os
import os.path as osp
import clip
import json
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
import numpy as np
import time
import shutil
from .aux import TrainingStatTracker, update_progress, update_stdout, sec2dhms
from .config import SEMANTIC_DIPOLES_CORPORA, STYLEGAN_LAYERS


class DataParallelPassthrough(nn.DataParallel):
    """nn.DataParallel wrapper that falls back to the wrapped module's own attributes."""
    def __getattr__(self, name):
        try:
            return super(DataParallelPassthrough, self).__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


class Trainer(object):
    def __init__(self, params=None, exp_dir=None, use_cuda=False, multi_gpu=False):
        if params is None:
            raise ValueError("Cannot build a Trainer instance with empty params: params={}".format(params))
        else:
            self.params = params
        self.use_cuda = use_cuda
        self.multi_gpu = multi_gpu

        # Set output directory for current experiment (wip)
        self.wip_dir = osp.join("experiments", "wip", exp_dir)
        os.makedirs(self.wip_dir, exist_ok=True)

        # Set directory for completed experiment
        self.complete_dir = osp.join("experiments", "complete", exp_dir)

        # Define stats json file
        self.stats_json = osp.join(self.wip_dir, 'stats.json')
        if not osp.isfile(self.stats_json):
            with open(self.stats_json, 'w') as out:
                json.dump({}, out)

        # Create models sub-directory
        self.models_dir = osp.join(self.wip_dir, 'models')
        os.makedirs(self.models_dir, exist_ok=True)

        # Define checkpoint model file
        self.checkpoint = osp.join(self.models_dir, 'checkpoint.pt')

        # Array of iteration times
        self.iter_times = np.array([])

        # Set up training statistics tracker
        self.stat_tracker = TrainingStatTracker()

        # Define cosine similarity loss
        self.cosine_embedding_loss = nn.CosineEmbeddingLoss()

        # Define cross entropy loss
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        # Define transform of CLIP image encoder
        self.clip_img_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))])

    def contrastive_loss(self, img_batch, txt_batch):
        # Image and text batches must have the same size and dimensionality
        n_img, d_img = img_batch.shape
        n_txt, d_txt = txt_batch.shape
        assert (n_img, d_img) == (n_txt, d_txt), \
            "Incompatible image/text batch shapes: {} vs. {}".format(img_batch.shape, txt_batch.shape)

        # Normalise image and text batches
        img_batch_l2 = F.normalize(img_batch, p=2, dim=-1)
        txt_batch_l2 = F.normalize(txt_batch, p=2, dim=-1)

        # Calculate inner product similarity matrix
        similarity_matrix = torch.matmul(img_batch_l2, txt_batch_l2.T)

        # The i-th image feature is matched to the i-th text feature, so the ground-truth class of each row of the
        # similarity matrix is its own index
        labels = torch.arange(n_img, device=img_batch.device)

        return self.cross_entropy_loss(similarity_matrix / self.params.temperature, labels)
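    # Illustrative note on `contrastive_loss` (a sketch, not part of the training flow): each i-th image feature is
    # treated as the positive match of the i-th text feature, and every other pairing acts as a negative, i.e., an
    # InfoNCE-style objective over the (batch_size x batch_size) similarity matrix. For a hypothetical `trainer`
    # instance whose `params.temperature` is, say, 0.07:
    #
    #   >>> img = torch.randn(4, 512)
    #   >>> txt = img.clone()                                    # perfectly aligned image/text pairs
    #   >>> trainer.contrastive_loss(img, txt)                   # small loss -- each row peaks on the diagonal
    #   >>> trainer.contrastive_loss(img, torch.randn(4, 512))   # larger loss for mismatched pairs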
    def get_starting_iteration(self, latent_support_sets):
        """Check if a checkpoint file exists (under `self.models_dir`); if so, set the starting iteration to the
        checkpoint iteration and load the checkpoint weights into `latent_support_sets`. Otherwise, set the starting
        iteration to 1 in order to train from scratch.

        Returns:
            starting_iter (int): starting iteration
        """
        starting_iter = 1
        if osp.isfile(self.checkpoint):
            checkpoint_dict = torch.load(self.checkpoint)
            starting_iter = checkpoint_dict['iter']
            latent_support_sets.load_state_dict(checkpoint_dict['latent_support_sets'])
        return starting_iter

    def log_progress(self, iteration, mean_iter_time, elapsed_time, eta):
        """Log progress (loss + ETA).

        Args:
            iteration (int)        : current iteration
            mean_iter_time (float) : mean iteration time
            elapsed_time (float)   : elapsed time until current iteration
            eta (float)            : estimated time of experiment completion
        """
        # Get current training stats (for the previous `self.params.log_freq` steps) and flush them
        stats = self.stat_tracker.get_means()

        # Update training statistics json file
        with open(self.stats_json) as f:
            stats_dict = json.load(f)
        stats_dict.update({iteration: stats})
        with open(self.stats_json, 'w') as out:
            json.dump(stats_dict, out)

        # Flush training statistics tracker
        self.stat_tracker.flush()

        update_progress("  \\__.Training [bs: {}] [iter: {:06d}/{:06d}] ".format(
            self.params.batch_size, iteration, self.params.max_iter), self.params.max_iter, iteration + 1)
        if iteration < self.params.max_iter - 1:
            print()
        print("     ===================================================================")
        print("      \\__Loss           : {:.08f}".format(stats['loss']))
        print("     ===================================================================")
        print("      \\__Mean iter time : {:.3f} sec".format(mean_iter_time))
        print("      \\__Elapsed time   : {}".format(sec2dhms(elapsed_time)))
        print("      \\__ETA            : {}".format(sec2dhms(eta)))
        print("     ===================================================================")
        update_stdout(8)

    def train(self, generator, latent_support_sets, corpus_support_sets, clip_model):
        """GANxPlainer training function.

        Args:
            generator           : non-trainable (pre-trained) GAN generator
            latent_support_sets : trainable LSS model -- interpretable latent paths model
            corpus_support_sets : non-trainable CSS model -- non-linear paths in the CLIP space
            clip_model          : non-trainable (pre-trained) CLIP model
        """
        # Save initial `latent_support_sets` model as `latent_support_sets_init.pt`
        torch.save(latent_support_sets.state_dict(), osp.join(self.models_dir, 'latent_support_sets_init.pt'))

        # Save initial `corpus_support_sets` model as `corpus_support_sets_init.pt`
        torch.save(corpus_support_sets.state_dict(), osp.join(self.models_dir, 'corpus_support_sets_init.pt'))

        # Save prompt corpus list to json
        with open(osp.join(self.models_dir, 'semantic_dipoles.json'), 'w') as json_f:
            json.dump(SEMANTIC_DIPOLES_CORPORA[self.params.corpus], json_f)

        # Upload models to GPU if `self.use_cuda` is set (i.e., if `args.cuda` is set and `torch.cuda.is_available()`
        # is True)
        if self.use_cuda:
            generator.cuda().eval()
            clip_model.cuda().eval()
            corpus_support_sets.cuda()
            latent_support_sets.cuda().train()
        else:
            generator.eval()
            clip_model.eval()
            latent_support_sets.train()

        # Set latent support sets (LSS) optimizer
        latent_support_sets_optim = torch.optim.Adam(latent_support_sets.parameters(), lr=self.params.lr)

        # Set learning rate scheduler -- reduce lr by a factor of 10 after 90% of the total number of training
        # iterations
        latent_support_sets_lr_scheduler = StepLR(optimizer=latent_support_sets_optim,
                                                  step_size=int(0.9 * self.params.max_iter),
                                                  gamma=0.1)

        # Get starting iteration
        starting_iter = self.get_starting_iteration(latent_support_sets)

        # Parallelize models over multiple GPUs, if available and `multi_gpu=True`
        if self.multi_gpu:
            print("#. Parallelize G and CLIP over {} GPUs...".format(torch.cuda.device_count()))
            # Parallelize generator G
            generator = DataParallelPassthrough(generator)
            # Parallelize CLIP model
            clip_model = DataParallelPassthrough(clip_model)

        # Check starting iteration
        if starting_iter == self.params.max_iter:
            print("#. This experiment has already been completed and can be found @ {}".format(self.wip_dir))
            print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
            try:
                shutil.copytree(src=self.wip_dir, dst=self.complete_dir,
                                ignore=shutil.ignore_patterns('checkpoint.pt'))
                print("  \\__Done!")
            except IOError as e:
                print("  \\__Already exists -- {}".format(e))
            sys.exit()
        print("#. Start training from iteration {}".format(starting_iter))

        # Get experiment's start time
        t0 = time.time()

        # Start training
        for iteration in range(starting_iter, self.params.max_iter + 1):
            # Get current iteration's start time
            iter_t0 = time.time()

            # Set gradients to zero
            generator.zero_grad()
            latent_support_sets.zero_grad()
            clip_model.zero_grad()

            # Sample latent codes from standard Gaussian
            z = torch.randn(self.params.batch_size, generator.dim_z)
            if self.use_cuda:
                z = z.cuda()

            # Generate images for the given latent codes
            latent_code = z
            if 'stylegan' in self.params.gan:
                if self.params.stylegan_space == 'W':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)[:, 0, :]
                elif self.params.stylegan_space == 'W+':
                    latent_code = generator.get_w(z, truncation=self.params.truncation)
            img = generator(latent_code)

            # Sample indices of shift vectors (`self.params.batch_size` out of
            # `latent_support_sets.num_support_sets`)
            target_support_sets_indices = torch.randint(0, latent_support_sets.num_support_sets,
                                                        [self.params.batch_size])
            if self.use_cuda:
                target_support_sets_indices = target_support_sets_indices.cuda()

            # Sample shift magnitudes from the uniform distributions
            #   U[self.params.min_shift_magnitude, self.params.max_shift_magnitude] and
            #   U[-self.params.max_shift_magnitude, -self.params.min_shift_magnitude]:
            # create a pool of 2 * `self.params.batch_size` shift magnitudes (half negative, half positive) and
            # sample `self.params.batch_size` of them uniformly without replacement (via `torch.randperm`, so that
            # all pool entries are equally likely)
            shift_magnitudes_pos = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
                torch.rand(target_support_sets_indices.size()) + self.params.max_shift_magnitude
            shift_magnitudes_neg = (self.params.min_shift_magnitude - self.params.max_shift_magnitude) * \
                torch.rand(target_support_sets_indices.size()) - self.params.min_shift_magnitude
            shift_magnitudes_pool = torch.cat((shift_magnitudes_neg, shift_magnitudes_pos))
            target_shift_magnitudes = shift_magnitudes_pool[
                torch.randperm(len(shift_magnitudes_pool))[:self.params.batch_size]]
            if self.use_cuda:
                target_shift_magnitudes = target_shift_magnitudes.cuda()
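            # E.g., with (hypothetical) min_shift_magnitude=0.1 and max_shift_magnitude=0.2, every sampled magnitude
            # lies in [-0.2, -0.1] or [0.1, 0.2]: degenerate (near-zero) shifts are excluded by construction, while
            # both traversal directions are represented in the pool.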
            # Create support sets mask of size (batch_size, num_support_sets) in the form:
            #   support_sets_mask[i] = [0, ..., 0, 1, 0, ..., 0]
            support_sets_mask = torch.zeros([self.params.batch_size, latent_support_sets.num_support_sets])
            prompt_mask = torch.zeros([self.params.batch_size, 2])
            prompt_sign = torch.zeros([self.params.batch_size, 1])
            if self.use_cuda:
                support_sets_mask = support_sets_mask.cuda()
                prompt_mask = prompt_mask.cuda()
                prompt_sign = prompt_sign.cuda()
            for i, (index, val) in enumerate(zip(target_support_sets_indices, target_shift_magnitudes)):
                support_sets_mask[i][index] += 1.0
                if val >= 0:
                    prompt_mask[i, 0] = 1.0
                    prompt_sign[i] = +1.0
                else:
                    prompt_mask[i, 1] = 1.0
                    prompt_sign[i] = -1.0
            prompt_mask = prompt_mask.unsqueeze(1)

            # Calculate shift vectors for the given latent codes -- in the case of StyleGAN, shifts live in
            # `self.params.stylegan_space`, i.e., in the Z-, W-, or W+-space. In the Z-/W-space, the dimensionality
            # of the latent space is 512; in the case of the W+-space, the dimensionality is
            # 512 * (self.params.stylegan_layer + 1)
            if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(
                    support_sets_mask,
                    latent_code[:, :self.params.stylegan_layer + 1, :].reshape(latent_code.shape[0], -1))
            else:
                shift = target_shift_magnitudes.reshape(-1, 1) * latent_support_sets(support_sets_mask, latent_code)

            # Generate images for the shifted latent codes -- in the W+ case, zero-pad the shift up to the full
            # 512 * STYLEGAN_LAYERS[self.params.gan] dimensionality, so that layers beyond
            # `self.params.stylegan_layer` are left unshifted
            if ('stylegan' in self.params.gan) and (self.params.stylegan_space == 'W+'):
                latent_code_reshaped = latent_code.reshape(latent_code.shape[0], -1)
                shift = F.pad(input=shift,
                              pad=(0, (STYLEGAN_LAYERS[self.params.gan] - 1 - self.params.stylegan_layer) * 512),
                              mode='constant',
                              value=0)
                latent_code_shifted = latent_code_reshaped + shift
                latent_code_shifted_reshaped = latent_code_shifted.reshape_as(latent_code)
                img_shifted = generator(latent_code_shifted_reshaped)
            else:
                img_shifted = generator(latent_code + shift)

            # Encode the original and the shifted images with the CLIP image encoder (in a single batch) and
            # calculate the CLIP image feature differences
            img_pairs = torch.cat([self.clip_img_transform(img), self.clip_img_transform(img_shifted)], dim=0)
            clip_img_pairs_features = clip_model.encode_image(img_pairs)
            clip_img_features, clip_img_shifted_features = torch.split(clip_img_pairs_features, img.shape[0], dim=0)
            clip_img_diff_features = clip_img_shifted_features - clip_img_features

            ########################################################################################################
            ##                                                                                                    ##
            ##                              Linear Text Paths (StyleCLIP approach)                                ##
            ##                                                                                                    ##
            ########################################################################################################
            if self.params.styleclip:
                corpus_text_features_batch = torch.matmul(
                    support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                        -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)
                corpus_text_features_batch = torch.matmul(prompt_mask, corpus_text_features_batch).squeeze(1)

                # Calculate cosine similarity loss
                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_shifted_features, corpus_text_features_batch,
                                                      torch.ones(corpus_text_features_batch.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                # Calculate contrastive loss
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(clip_img_shifted_features.float(), corpus_text_features_batch)

            ########################################################################################################
            ##                                                                                                    ##
            ##                                       Linear Text Paths                                            ##
            ##                                                                                                    ##
            ########################################################################################################
            elif self.params.linear:
                corpus_text_features_batch = torch.matmul(
                    support_sets_mask, corpus_support_sets.SUPPORT_SETS).reshape(
                        -1, 2 * corpus_support_sets.num_support_dipoles, corpus_support_sets.support_vectors_dim)

                # Calculate cosine similarity loss
                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_diff_features,
                                                      prompt_sign * (corpus_text_features_batch[:, 0, :] -
                                                                     corpus_text_features_batch[:, 1, :]) -
                                                      clip_img_features,
                                                      torch.ones(corpus_text_features_batch.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                # Calculate contrastive loss
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(clip_img_diff_features.float(),
                                                 prompt_sign * (corpus_text_features_batch[:, 0, :] -
                                                                corpus_text_features_batch[:, 1, :]) -
                                                 clip_img_features)
            ########################################################################################################
            ##                                                                                                    ##
            ##                                      Non-linear Text Paths                                         ##
            ##                                                                                                    ##
            ########################################################################################################
            else:
                # Calculate local text directions using the CSS
                local_text_directions = target_shift_magnitudes.reshape(-1, 1) * corpus_support_sets(
                    support_sets_mask, clip_img_features)

                # Calculate cosine similarity loss
                if self.params.loss == 'cossim':
                    loss = self.cosine_embedding_loss(clip_img_diff_features, local_text_directions,
                                                      torch.ones(local_text_directions.shape[0]).to(
                                                          'cuda' if self.use_cuda else 'cpu'))
                # Calculate contrastive loss
                elif self.params.loss == 'contrastive':
                    loss = self.contrastive_loss(img_batch=clip_img_diff_features.float(),
                                                 txt_batch=local_text_directions)

            # Back-propagate!
            loss.backward()

            # Update weights -- convert the CLIP model to full precision (float32) for the optimizer step and back
            # to mixed precision afterwards
            clip_model.float()
            latent_support_sets_optim.step()
            latent_support_sets_lr_scheduler.step()
            clip.model.convert_weights(clip_model)

            # Update statistics tracker
            self.stat_tracker.update(loss=loss.item())

            # Get time of completion of current iteration
            iter_t = time.time()

            # Compute elapsed time for current iteration and append to `iter_times`
            self.iter_times = np.append(self.iter_times, iter_t - iter_t0)

            # Compute elapsed time so far
            elapsed_time = iter_t - t0

            # Compute rolling mean iteration time
            mean_iter_time = self.iter_times.mean()

            # Compute estimated time of experiment completion
            eta = elapsed_time * ((self.params.max_iter - iteration) / (iteration - starting_iter + 1))

            # Log progress in stdout
            if iteration % self.params.log_freq == 0:
                self.log_progress(iteration, mean_iter_time, elapsed_time, eta)

            # Save checkpoint model file and latent support sets (LSS) model state dict after current iteration
            if iteration % self.params.ckp_freq == 0:
                # Build checkpoint dict
                checkpoint_dict = {
                    'iter': iteration,
                    'latent_support_sets': latent_support_sets.state_dict(),
                }
                torch.save(checkpoint_dict, self.checkpoint)
        # === End of training loop ===

        # Get experiment's total elapsed time
        elapsed_time = time.time() - t0

        # Save final latent support sets (LSS) model
        latent_support_sets_model_filename = osp.join(self.models_dir, 'latent_support_sets.pt')
        torch.save(latent_support_sets.state_dict(), latent_support_sets_model_filename)

        for _ in range(10):
            print()
        print("#. Training completed -- Total elapsed time: {}.".format(sec2dhms(elapsed_time)))

        print("#. Copy {} to {}...".format(self.wip_dir, self.complete_dir))
        try:
            shutil.copytree(src=self.wip_dir, dst=self.complete_dir)
            print("  \\__Done!")
        except IOError as e:
            print("  \\__Already exists -- {}".format(e))
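
# ---------------------------------------------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only). The `params` fields below mirror the attributes accessed above, but     #
# their values, as well as the generator/LSS/CSS builders, are hypothetical placeholders -- the project's actual    #
# entry point defines its own argument parser and model builders.                                                   #
# ---------------------------------------------------------------------------------------------------------------- #
def _example_usage():
    from types import SimpleNamespace

    # Hypothetical hyper-parameters -- every field corresponds to a `self.params.*` access in `Trainer`
    params = SimpleNamespace(
        gan='stylegan2_ffhq1024',            # hypothetical generator key
        stylegan_space='W+', stylegan_layer=11, truncation=0.7,
        corpus='expressions',                # hypothetical key into SEMANTIC_DIPOLES_CORPORA
        batch_size=3, max_iter=10000, lr=1e-3,
        min_shift_magnitude=0.1, max_shift_magnitude=0.2,
        temperature=0.07, loss='contrastive', styleclip=False, linear=False,
        log_freq=10, ckp_freq=1000)
    use_cuda = torch.cuda.is_available()

    trainer = Trainer(params=params, exp_dir='example_experiment', use_cuda=use_cuda)
    clip_model, _ = clip.load('ViT-B/32', device='cuda' if use_cuda else 'cpu')
    # generator, latent_support_sets, corpus_support_sets = ...  # built by the project's own model builders
    # trainer.train(generator, latent_support_sets, corpus_support_sets, clip_model)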