"""Training script for VATr: builds the train/validation datasets, restores or
initializes the model, and runs the adversarial training loop with optional
wandb logging."""

import argparse
import os
import random
import time

import numpy as np
import torch
import wandb

from data.dataset import TextDataset, CollectionTextDataset
from models.model import VATr
from util.misc import EpochLossTracker, add_vatr_args


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action='store_true')
    parser = add_vatr_args(parser)
    args = parser.parse_args()
    rSeed(args.seed)

    dataset = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix,
        num_examples=args.num_examples, collator_resolution=args.resolution,
        min_virtual_size=339, validation=False, debug=False, height=args.img_height
    )
    datasetval = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix,
        num_examples=args.num_examples, collator_resolution=args.resolution,
        min_virtual_size=161, validation=True, height=args.img_height
    )
    args.num_writers = dataset.num_writers

    # IAM and CVL use a fixed alphabet; other datasets derive theirs from the data.
    if args.dataset in ('IAM', 'CVL'):
        args.alphabet = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%'
    else:
        args.alphabet = ''.join(sorted(set(dataset.alphabet + datasetval.alphabet)))
    args.special_alphabet = ''.join(c for c in args.special_alphabet if c not in dataset.alphabet)

    args.exp_name = f"{args.dataset}-{args.num_writers}-{args.num_examples}-LR{args.g_lr}-bs{args.batch_size}-{args.tag}"

    # Log only scalar arguments to wandb; disable wandb when running on a Tesla K80.
    config = {k: v for k, v in args.__dict__.items() if isinstance(v, (bool, int, str, float))}
    args.wandb = args.wandb and (not torch.cuda.is_available() or torch.cuda.get_device_name(0) != 'Tesla K80')
    wandb_id = wandb.util.generate_id()

    MODEL_PATH = os.path.join(args.save_model_path, args.exp_name)
    os.makedirs(MODEL_PATH, exist_ok=True)

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True,
        collate_fn=dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(
        datasetval, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True,
        collate_fn=datasetval.collate_fn)

    model = VATr(args)
    start_epoch = 0

    # The alphabets are long strings; keep them out of the wandb config.
    del config['alphabet']
    del config['special_alphabet']
    wandb_params = {
        'project': 'VATr',
        'config': config,
        'name': args.exp_name,
        'id': wandb_id
    }

    checkpoint_path = os.path.join(MODEL_PATH, 'model.pth')
    loss_tracker = EpochLossTracker()

    if args.resume and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=args.device)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch'] + 1  # the saved epoch already finished
        wandb_params['id'] = checkpoint['wandb_id']
        wandb_params['resume'] = True
        print(f'{checkpoint_path}: model loaded successfully')
    elif args.resume:
        raise FileNotFoundError(f'No model found at {checkpoint_path}')
    elif args.feat_model_path is not None and args.feat_model_path.lower() != 'none':
        # Warm-start the generator's feature encoder from a pretrained ResNet-18:
        # average the RGB conv1 weights down to a single grayscale channel and
        # drop the classification head before the non-strict load.
        print('Loading...', args.feat_model_path)
        assert os.path.exists(args.feat_model_path)
        checkpoint = torch.load(args.feat_model_path, map_location=args.device)
        checkpoint['model']['conv1.weight'] = checkpoint['model']['conv1.weight'].mean(1).unsqueeze(1)
        del checkpoint['model']['fc.weight']
        del checkpoint['model']['fc.bias']
        model.netG.Feat_Encoder.load_state_dict(checkpoint['model'], strict=False)
    else:
        print('WARNING: No resume of ResNet-18, starting from scratch')

    if args.wandb:
        wandb.init(**wandb_params)
        wandb.watch(model)

    print("Starting training")

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        log_time = time.time()

        loss_tracker.reset()
        model.d_acc.update(0.0)

        if args.text_augment_strength > 0:
            model.set_text_aug_strength(args.text_augment_strength)

        for i, data in enumerate(train_loader):
            model.update_parameters(epoch)
            model._set_input(data)

            # Alternating updates: generator, then discriminator + OCR network,
            # then generator again (writer loss), then writer classifier.
            model.optimize_G_only()
            model.optimize_G_step()
            model.optimize_D_OCR()
            model.optimize_D_OCR_step()
            model.optimize_G_WL()
            model.optimize_G_step()
            model.optimize_D_WL()
            model.optimize_D_WL_step()

            # Print a progress line at most every 10 seconds.
            if time.time() - log_time > 10:
                print(f'Epoch {epoch} {i / len(train_loader) * 100:.2f}% done, '
                      f'elapsed: {time.time() - start_time:.2f} s')
                log_time = time.time()

            batch_losses = model.get_current_losses()
            batch_losses['d_acc'] = model.d_acc.avg
            loss_tracker.add_batch(batch_losses)

        end_time = time.time()
        data_val = next(iter(val_loader))
        losses = loss_tracker.get_epoch_loss()

        # Render sample pages from the last training batch and a validation batch.
        page = model._generate_page(model.sdata, model.input['swids'])
        page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids'])

        d_train, d_val, d_fake = model.compute_d_stats(train_loader, val_loader)

        if args.wandb:
            wandb.log({
                'loss-G': losses['G'],
                'loss-D': losses['D'],
                'loss-Dfake': losses['Dfake'],
                'loss-Dreal': losses['Dreal'],
                'loss-OCR_fake': losses['OCR_fake'],
                'loss-OCR_real': losses['OCR_real'],
                'loss-w_fake': losses['w_fake'],
                'loss-w_real': losses['w_real'],
                'd_acc': losses['d_acc'],
                'd-rv': (d_train - d_val) / (d_train - d_fake),
                'd-fake': d_fake,
                'd-real': d_train,
                'd-val': d_val,
                'l_cycle': losses['cycle'],
                'epoch': epoch,
                'timeperepoch': end_time - start_time,
                'result': [wandb.Image(page, caption="page"),
                           wandb.Image(page_val, caption="page_val")],
                'd-crop-size': model.netD.augmenter.get_current_width() if model.netD.crop else 0
            })

        print({'EPOCH': epoch, 'TIME': end_time - start_time, 'LOSSES': losses})
        print(f"Text sample: {model.get_text_sample(10)}")

        checkpoint = {
            'model': model.state_dict(),
            'wandb_id': wandb_id,
            'epoch': epoch
        }

        # Overwrite the rolling checkpoint and keep periodic history snapshots.
        if epoch % args.save_model == 0:
            torch.save(checkpoint, os.path.join(MODEL_PATH, 'model.pth'))
        if epoch % args.save_model_history == 0:
            torch.save(checkpoint, os.path.join(MODEL_PATH, f'{epoch:04d}_model.pth'))


def rSeed(sd):
    # Seed every RNG in use so runs are reproducible.
    random.seed(sd)
    np.random.seed(sd)
    torch.manual_seed(sd)
    torch.cuda.manual_seed(sd)


if __name__ == "__main__":
    print("Training Model")
    main()
    wandb.finish()