import os
import argparse
import random
import logging

import torch
import wandb
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path

from utils import __balance_val_split, __split_of_train_sequence
from datasets.czech_slr_dataset import CzechSLRDataset
from spoter.spoter_model import SPOTER
from spoter.utils import train_epoch, evaluate
from spoter.gaussian_noise import GaussianNoise


def get_default_args():
    parser = argparse.ArgumentParser(add_help=False)

    parser.add_argument("--experiment_name", type=str, default="lsa_64_spoter",
                        help="Name of the experiment after which the logs and plots will be named")
    parser.add_argument("--num_classes", type=int, default=64, help="Number of classes to be recognized by the model")
    parser.add_argument("--hidden_dim", type=int, default=108,
                        help="Hidden dimension of the underlying Transformer model")
    parser.add_argument("--seed", type=int, default=379,
                        help="Seed with which to initialize all the random components of the training")

    # Data
    parser.add_argument("--training_set_path", type=str, default="", help="Path to the training dataset CSV file")
    parser.add_argument("--testing_set_path", type=str, default="", help="Path to the testing dataset CSV file")
    parser.add_argument("--experimental_train_split", type=float, default=None,
                        help="Determines how big a portion of the training set should be employed (intended for the "
                             "gradually enlarging training set experiment from the paper)")
    parser.add_argument("--validation_set", type=str, choices=["from-file", "split-from-train", "none"],
                        default="from-file", help="Type of validation set construction. See README for further reference")
    parser.add_argument("--validation_set_size", type=float,
                        help="Proportion of the training set to be split off as the validation set, if "
                             "'validation_set' is set to 'split-from-train'")
    parser.add_argument("--validation_set_path", type=str, default="", help="Path to the validation dataset CSV file")

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model for")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the model training")
    parser.add_argument("--log_freq", type=int, default=1,
                        help="Log frequency (frequency of printing all the training info)")

    # Checkpointing
    # (a bare type=bool would treat any non-empty string, including "False",
    # as True, hence the explicit string-to-bool conversion)
    parser.add_argument("--save_checkpoints", type=lambda s: str(s).lower() not in ("false", "0", ""),
                        default=True, help="Determines whether to save weights checkpoints")

    # Scheduler
    parser.add_argument("--scheduler_factor", type=float, default=0.1,
                        help="Factor for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_patience", type=int, default=5,
                        help="Patience for the ReduceLROnPlateau scheduler")

    # Gaussian noise augmentation
    parser.add_argument("--gaussian_mean", type=float, default=0, help="Mean parameter for the Gaussian noise layer")
    parser.add_argument("--gaussian_std", type=float, default=0.001,
                        help="Standard deviation parameter for the Gaussian noise layer")

    # Geometric augmentations (trailing comments hold previously tuned values)
    parser.add_argument("--augmentations_probability", type=float, default=0.5,
                        help="Probability with which augmentations are applied to a training sample")  # 0.462
    parser.add_argument("--rotate_angle", type=int, default=17,
                        help="Maximum angle for the rotate augmentation")  # 17
    parser.add_argument("--perspective_transform_ratio", type=float, default=0.2,
                        help="Ratio for the perspective transform augmentation")  # 0.1682
    parser.add_argument("--squeeze_ratio", type=float, default=0.4,
                        help="Ratio for the squeeze augmentation")  # 0.3971
    parser.add_argument("--arm_joint_rotate_angle", type=int, default=4,
                        help="Maximum angle for the arm-joint rotation augmentation")  # 3
    parser.add_argument("--arm_joint_rotate_probability", type=float, default=0.4,
                        help="Probability with which the arm-joint rotation is applied")  # 0.3596

    # Visualization
    parser.add_argument("--plot_stats", type=lambda s: str(s).lower() not in ("false", "0", ""), default=True,
                        help="Determines whether continuous statistics should be plotted at the end")
    parser.add_argument("--plot_lr", type=lambda s: str(s).lower() not in ("false", "0", ""), default=True,
                        help="Determines whether the LR should be plotted at the end")

    # WANDB
    parser.add_argument("--wandb_key", type=str, default="", help="W&B API key; leave empty to disable W&B logging")
    parser.add_argument("--wandb_entity", type=str, default="", help="W&B entity (user or team) to log runs under")

    return parser
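

# A minimal usage sketch (hypothetical; not called anywhere in this script):
# the default parser can be combined with programmatic overrides, e.g. for a
# quick smoke test, mirroring the __main__ block at the bottom of the file.
def _example_smoke_test_args():
    parser = argparse.ArgumentParser("", parents=[get_default_args()], add_help=False)
    return parser.parse_args(["--num_classes", "64", "--hidden_dim", "54", "--epochs", "2"])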


def train(args):
    if args.wandb_key:
        wandb.login(key=args.wandb_key)
        wandb.init(project=args.experiment_name, entity=args.wandb_entity)
        wandb.config.update(args)
        # wandb.run only exists after wandb.init(), so the run-specific name
        # suffix can only be applied when W&B logging is enabled
        args.experiment_name = args.experiment_name + "_lr" + wandb.run.id

    # MARK: TRAINING PREPARATION AND MODULES

    # Initialize all the random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    os.environ["PYTHONHASHSEED"] = str(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(args.seed)
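
    # For stricter run-to-run reproducibility on GPU, cuDNN autotuning could
    # additionally be disabled (a sketch beyond what this script configures):
    #
    #     torch.backends.cudnn.benchmark = False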

    # Save all the training info into a LOG file (console output is produced
    # by the print calls that accompany each logging call below)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(args.experiment_name + "_" + str(args.experimental_train_split).replace(".", "") + ".log")
        ]
    )

    # Set device to CUDA only if applicable
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # Construct the model
    slrt_model = SPOTER(num_classes=args.num_classes, hidden_dim=args.hidden_dim)
    slrt_model.train(True)
    slrt_model.to(device)
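
    # Optional sanity check (a sketch, not part of the original flow): report
    # the number of trainable parameters before training starts.
    #
    #     num_params = sum(p.numel() for p in slrt_model.parameters() if p.requires_grad)
    #     print("Trainable parameters: " + str(num_params))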

    # Construct the other modules
    cel_criterion = nn.CrossEntropyLoss()
    sgd_optimizer = optim.SGD(slrt_model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(sgd_optimizer, factor=args.scheduler_factor,
                                                     patience=args.scheduler_patience)

    # Ensure that the paths for checkpoints and images both exist
    Path("out-checkpoints/" + args.experiment_name + "/").mkdir(parents=True, exist_ok=True)
    Path("out-img/").mkdir(parents=True, exist_ok=True)

    # MARK: DATA

    # Training set
    transform = transforms.Compose([GaussianNoise(args.gaussian_mean, args.gaussian_std)])
    augmentations_config = {
        "rotate-angle": args.rotate_angle,
        "perspective-transform-ratio": args.perspective_transform_ratio,
        "squeeze-ratio": args.squeeze_ratio,
        "arm-joint-rotate-angle": args.arm_joint_rotate_angle,
        "arm-joint-rotate-probability": args.arm_joint_rotate_probability
    }
    train_set = CzechSLRDataset(args.training_set_path, transform=transform, augmentations=True,
                                augmentations_prob=args.augmentations_probability,
                                augmentations_config=augmentations_config)

    # Validation set
    if args.validation_set == "from-file":
        val_set = CzechSLRDataset(args.validation_set_path)
        val_loader = DataLoader(val_set, shuffle=True, generator=g)
    elif args.validation_set == "split-from-train":
        # Use the CLI-configurable split size, falling back to the original
        # hard-coded 0.2 when --validation_set_size is not given
        val_split_size = args.validation_set_size if args.validation_set_size else 0.2
        train_set, val_set = __balance_val_split(train_set, val_split_size)
        val_set.transform = None
        val_set.augmentations = False
        val_loader = DataLoader(val_set, shuffle=True, generator=g)
    else:
        val_loader = None

    # Testing set
    if args.testing_set_path:
        eval_set = CzechSLRDataset(args.testing_set_path)
        eval_loader = DataLoader(eval_set, shuffle=True, generator=g)
    else:
        eval_loader = None

    # Final training set refinements
    if args.experimental_train_split:
        train_set = __split_of_train_sequence(train_set, args.experimental_train_split)

    train_loader = DataLoader(train_set, shuffle=True, generator=g)
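
    # Note: the DataLoaders above run with the defaults batch_size=1 and
    # num_workers=0. If worker processes were enabled, reproducible shuffling
    # would also need per-worker seeding (a sketch following the PyTorch
    # randomness guidelines; num_workers=4 is an arbitrary example value):
    #
    #     def seed_worker(worker_id):
    #         worker_seed = torch.initial_seed() % 2 ** 32
    #         np.random.seed(worker_seed)
    #         random.seed(worker_seed)
    #
    #     train_loader = DataLoader(train_set, shuffle=True, generator=g,
    #                               num_workers=4, worker_init_fn=seed_worker)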

    # MARK: TRAINING
    train_acc, val_acc = 0, 0
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0

    if args.experimental_train_split:
        print("Starting " + args.experiment_name + "_" + str(args.experimental_train_split).replace(".", "") + "...\n\n")
        logging.info("Starting " + args.experiment_name + "_" + str(args.experimental_train_split).replace(".", "") + "...\n\n")
    else:
        print("Starting " + args.experiment_name + "...\n\n")
        logging.info("Starting " + args.experiment_name + "...\n\n")

    for epoch in range(args.epochs):
        train_loss, _, _, train_acc = train_epoch(slrt_model, train_loader, cel_criterion, sgd_optimizer, device)
        losses.append(train_loss.item() / len(train_loader))
        train_accs.append(train_acc)

        if val_loader:
            slrt_model.train(False)
            _, _, val_acc = evaluate(slrt_model, val_loader, device)
            slrt_model.train(True)
            val_accs.append(val_acc)

        # Step the ReduceLROnPlateau scheduler on the epoch's mean training
        # loss (without this call the scheduler constructed above would never
        # adjust the learning rate)
        scheduler.step(train_loss.item() / len(train_loader))

        # Save checkpoints if they are best in the current subset
        if args.save_checkpoints:
            if train_acc > top_train_acc:
                top_train_acc = train_acc
                torch.save(slrt_model, "out-checkpoints/" + args.experiment_name + "/checkpoint_t_" + str(checkpoint_index) + ".pth")

            if val_acc > top_val_acc:
                top_val_acc = val_acc
                torch.save(slrt_model, "out-checkpoints/" + args.experiment_name + "/checkpoint_v_" + str(checkpoint_index) + ".pth")
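        # Note: torch.save is applied to the whole module here, which pickles
        # the SPOTER class itself; a state_dict checkpoint
        # (torch.save(slrt_model.state_dict(), path)) would be more portable,
        # but the full-module format is what the testing phase below reloads
        # via torch.load.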

        if epoch % args.log_freq == 0:
            print("[" + str(epoch + 1) + "] TRAIN loss: " + str(train_loss.item() / len(train_loader)) + " acc: " + str(train_acc))
            logging.info("[" + str(epoch + 1) + "] TRAIN loss: " + str(train_loss.item() / len(train_loader)) + " acc: " + str(train_acc))

            # Guard all W&B calls so that runs without a key do not crash
            if args.wandb_key:
                wandb.log({
                    "epoch": int(epoch + 1),
                    "train-loss": float(train_loss.item() / len(train_loader)),
                    "train-accuracy": train_acc
                })

            if val_loader:
                print("[" + str(epoch + 1) + "] VALIDATION acc: " + str(val_acc))
                logging.info("[" + str(epoch + 1) + "] VALIDATION acc: " + str(val_acc))

                if args.wandb_key:
                    wandb.log({
                        "validation-accuracy": val_acc
                    })

            print("")
            logging.info("")

        # Reset the top accuracies on static subsets
        if epoch % 10 == 0:
            top_train_acc, top_val_acc = 0, 0
            checkpoint_index += 1

        lr_progress.append(sgd_optimizer.param_groups[0]["lr"])

    # MARK: TESTING
    print("\nTesting checkpointed models starting...\n")
    logging.info("\nTesting checkpointed models starting...\n")

    top_result, top_result_name = 0, ""

    if eval_loader:
        for i in range(checkpoint_index):
            for checkpoint_id in ["t", "v"]:
                checkpoint_path = "out-checkpoints/" + args.experiment_name + "/checkpoint_" + checkpoint_id + "_" + str(i) + ".pth"
                # Skip indices for which no checkpoint was written
                if not os.path.exists(checkpoint_path):
                    continue

                tested_model = torch.load(checkpoint_path, map_location=device)
                tested_model.train(False)
                _, _, eval_acc = evaluate(tested_model, eval_loader, device, print_stats=True)

                if eval_acc > top_result:
                    top_result = eval_acc
                    top_result_name = args.experiment_name + "/checkpoint_" + checkpoint_id + "_" + str(i)

                print("checkpoint_" + checkpoint_id + "_" + str(i) + " -> " + str(eval_acc))
                logging.info("checkpoint_" + checkpoint_id + "_" + str(i) + " -> " + str(eval_acc))

        print("\nThe top result was recorded at " + str(top_result) + " testing accuracy. The best checkpoint is " + top_result_name + ".")
        logging.info("\nThe top result was recorded at " + str(top_result) + " testing accuracy. The best checkpoint is " + top_result_name + ".")

        if args.wandb_key:
            wandb.run.summary["best-accuracy"] = top_result
            wandb.run.summary["best-checkpoint"] = top_result_name
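
    # Stand-alone inference sketch (hypothetical; the checkpoint file name is
    # a placeholder for one produced by the loop above):
    #
    #     device = torch.device("cpu")
    #     model = torch.load("out-checkpoints/<experiment_name>/checkpoint_v_0.pth",
    #                        map_location=device)
    #     model.train(False)
    #     _, _, acc = evaluate(model, eval_loader, device, print_stats=True)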

    # PLOT 0: Performance (loss, accuracies) chart plotting
    if args.plot_stats:
        fig, ax = plt.subplots()
        ax.plot(range(1, len(losses) + 1), losses, c="#D64436", label="Training loss")
        ax.plot(range(1, len(train_accs) + 1), train_accs, c="#00B09B", label="Training accuracy")

        if val_loader:
            ax.plot(range(1, len(val_accs) + 1), val_accs, c="#E0A938", label="Validation accuracy")

        ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        ax.set(xlabel="Epoch", ylabel="Accuracy / Loss", title="")
        plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.05), ncol=4, fancybox=True, shadow=True,
                   fontsize="xx-small")
        ax.grid()

        fig.savefig("out-img/" + args.experiment_name + "_loss.png")

    # PLOT 1: Learning rate progress
    if args.plot_lr:
        fig1, ax1 = plt.subplots()
        ax1.plot(range(1, len(lr_progress) + 1), lr_progress, label="LR")
        ax1.set(xlabel="Epoch", ylabel="LR", title="")
        ax1.grid()

        fig1.savefig("out-img/" + args.experiment_name + "_lr.png")

    print("\nAny desired statistics have been plotted.\nThe experiment is finished.")
    logging.info("\nAny desired statistics have been plotted.\nThe experiment is finished.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser("", parents=[get_default_args()], add_help=False)
    args = parser.parse_args()

    train(args)
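
# Example invocation (assuming this file is saved as train.py; all CSV paths
# are placeholders):
#
#     python train.py --experiment_name lsa_64_spoter \
#         --training_set_path <train.csv> \
#         --validation_set from-file --validation_set_path <val.csv> \
#         --testing_set_path <test.csv>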