File size: 1,662 Bytes
1964059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from train import train
from argparse import ArgumentParser
def save_log(file, str):
    """Append a brief log entry for one training run of the pipeline.

    Args:
        file: Path of the log file; it is created if it does not exist.
        str: Text appended verbatim (no newline is added here).
    """
    # The parameter name shadows the builtin ``str`` inside this function; it
    # is kept so existing keyword callers (``save_log(file=..., str=...)``)
    # keep working. The builtin is not needed in this body.
    # ``with`` guarantees the handle is closed even if the write raises.
    with open(file, "a", encoding="utf-8") as pipeline_log:
        pipeline_log.write(str)
if __name__ == "__main__":
    # Pipeline which directly calls the train function of train.py with the
    # argument combinations needed to reproduce the paper results.
    parser = ArgumentParser()
    parser.add_argument("--model_name")
    parser.add_argument("--size")
    parser.add_argument("--cropped")
    parser.add_argument("--device")
    parser.add_argument("--labels")
    # NOTE(review): every option parsed here is overwritten in the loop below,
    # so values given on the command line have no effect on the runs; parsing
    # only serves to build the Namespace (and reject unknown options).
    args = parser.parse_args()

    # (size, cropped) configurations trained for every label/model pair, in
    # the order they are run. A configuration's index in this tuple is its
    # session offset within one repetition.
    # NOTE(review): the first entry passes the bool False while the others
    # pass the strings "False"/"True" (the form argparse would produce from
    # the CLI). Kept as-is to preserve behaviour; confirm how train() reads
    # args.cropped before unifying these.
    configs = (
        ("512", False),
        ("224", "False"),
        ("224", "True"),
    )

    # Pipeline launched for 5 repetitions of the full training grid.
    for i in range(5):
        for label in ["pressure", "wind"]:
            for model in ["resnet18", "resnet50"]:
                for offset, (size, cropped) in enumerate(configs):
                    (
                        args.model_name,
                        args.size,
                        args.cropped,
                        args.device,
                        args.labels,
                    ) = model, size, cropped, 0, label
                    train_log = train(args)
                    # Session number = i * len(configs) + offset. It is unique
                    # within one (label, model) pair and repeats across pairs;
                    # the logged args distinguish those runs.
                    session = i * len(configs) + offset
                    save_log(
                        "pipeline_logs.txt",
                        "training session " + str(session) + " : "
                        + str(args) + " " + train_log + "\n",
                    )
|