dkoshman committed
Commit 11c4819
1 Parent(s): a31e03c

fixed image transform

Files changed (2)
  1. data_preprocessing.py +2 -2
  2. train.py +4 -5
data_preprocessing.py CHANGED
@@ -74,7 +74,7 @@ class RandomizeImageTransform(object):
 
     def __init__(self, width, height, random_magnitude):
         self.transform = T.Compose((
-            lambda x: x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
+            (lambda x: x) if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
                                                                     contrast=random_magnitude / 10,
                                                                     saturation=random_magnitude / 10,
                                                                     hue=min(0.5, random_magnitude / 10)),
@@ -83,7 +83,7 @@ class RandomizeImageTransform(object):
             T.functional.invert,
             T.CenterCrop((height, width)),
             torch.Tensor.contiguous,
-            lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
+            (lambda x: x) if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
             T.ConvertImageDtype(torch.float32)
         ))
 
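
Why the added parentheses matter: in Python the conditional expression binds inside the lambda, so the old Compose entry was always a lambda, and with a non-zero magnitude calling it on an image returned a freshly constructed ColorJitter (or RandAugment) object instead of an augmented image. A minimal sketch of the two parses, assuming `T` is `torchvision.transforms` as elsewhere in the file:

import torchvision.transforms as T  # assumption: same alias as in data_preprocessing.py

random_magnitude = 9

# Old parse: the whole conditional is the lambda body, so this entry is always a
# lambda; with random_magnitude != 0, calling it returns a ColorJitter *instance*.
buggy_entry = lambda x: (x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10))

# Fixed parse: the conditional picks between the identity function and the transform
# object itself, so the Compose entry always maps image -> image.
fixed_entry = (lambda x: x) if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10)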
 
train.py CHANGED
@@ -13,8 +13,7 @@ import torch
 
 
 def check_setup():
-    print(
-        "Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out")
+    # Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     if not os.path.isfile(DATAMODULE_PATH):
         print("Generating default datamodule")
@@ -107,7 +106,7 @@ def main():
     callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
                  LearningRateMonitor(logging_interval="step"),
                  ModelCheckpoint(save_top_k=10,
-                                 every_n_train_steps=500,
+                                 every_n_train_steps=5,
                                  monitor="val_loss",
                                  mode="min",
                                  filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
@@ -135,9 +134,9 @@ def main():
     trainer.fit(transformer, datamodule=datamodule)
     trainer.test(transformer, datamodule=datamodule)
 
-    if args.log:
+    if args.log and len(os.listdir(trainer.checkpoint_callback.dirpath)):
         transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
-        transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
+        transformer_path = os.path.join(RESOURCES, f"model_{trainer.logger.version}.pt")
         transformer.eval()
         transformer.freeze()
         torch.save(transformer.state_dict(), transformer_path)
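
The `args.log` branch now also checks that the checkpoint directory is non-empty before averaging, so `average_checkpoints` is never called with zero checkpoints, and the exported weights gain a `model_` filename prefix. The repo's own `average_checkpoints` helper is not part of this diff; a minimal sketch of what state-dict averaging of this shape could look like, purely as an illustrative assumption (the no-argument constructor and iterating every file in the directory are assumptions, not the project's actual implementation):

import os
import torch

def average_checkpoints_sketch(model_type, checkpoints_dir):
    # Hypothetical illustration: average parameter tensors across every checkpoint
    # in the directory, then load the result into a freshly constructed model.
    paths = [os.path.join(checkpoints_dir, name) for name in os.listdir(checkpoints_dir)]
    state_dicts = [torch.load(path, map_location="cpu")["state_dict"] for path in paths]
    averaged = {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }
    model = model_type()  # assumption: illustrative no-argument constructor
    model.load_state_dict(averaged)
    return model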