Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
•
11c4819
1
Parent(s):
a31e03c
fixed image transform
Browse files- data_preprocessing.py +2 -2
- train.py +4 -5
data_preprocessing.py
CHANGED
@@ -74,7 +74,7 @@ class RandomizeImageTransform(object):
|
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
|
76 |
self.transform = T.Compose((
|
77 |
-
lambda x: x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
|
78 |
contrast=random_magnitude / 10,
|
79 |
saturation=random_magnitude / 10,
|
80 |
hue=min(0.5, random_magnitude / 10)),
|
@@ -83,7 +83,7 @@ class RandomizeImageTransform(object):
|
|
83 |
T.functional.invert,
|
84 |
T.CenterCrop((height, width)),
|
85 |
torch.Tensor.contiguous,
|
86 |
-
lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
|
87 |
T.ConvertImageDtype(torch.float32)
|
88 |
))
|
89 |
|
|
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
|
76 |
self.transform = T.Compose((
|
77 |
+
(lambda x: x) if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
|
78 |
contrast=random_magnitude / 10,
|
79 |
saturation=random_magnitude / 10,
|
80 |
hue=min(0.5, random_magnitude / 10)),
|
|
|
83 |
T.functional.invert,
|
84 |
T.CenterCrop((height, width)),
|
85 |
torch.Tensor.contiguous,
|
86 |
+
(lambda x: x) if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude),
|
87 |
T.ConvertImageDtype(torch.float32)
|
88 |
))
|
89 |
|
train.py
CHANGED
@@ -13,8 +13,7 @@ import torch
|
|
13 |
|
14 |
|
15 |
def check_setup():
|
16 |
-
|
17 |
-
"Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out")
|
18 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
19 |
if not os.path.isfile(DATAMODULE_PATH):
|
20 |
print("Generating default datamodule")
|
@@ -107,7 +106,7 @@ def main():
|
|
107 |
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
|
108 |
LearningRateMonitor(logging_interval="step"),
|
109 |
ModelCheckpoint(save_top_k=10,
|
110 |
-
every_n_train_steps=
|
111 |
monitor="val_loss",
|
112 |
mode="min",
|
113 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
@@ -135,9 +134,9 @@ def main():
|
|
135 |
trainer.fit(transformer, datamodule=datamodule)
|
136 |
trainer.test(transformer, datamodule=datamodule)
|
137 |
|
138 |
-
if args.log:
|
139 |
transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
|
140 |
-
transformer_path = os.path.join(RESOURCES, f"{trainer.logger.version}.pt")
|
141 |
transformer.eval()
|
142 |
transformer.freeze()
|
143 |
torch.save(transformer.state_dict(), transformer_path)
|
|
|
13 |
|
14 |
|
15 |
def check_setup():
|
16 |
+
# Disabling tokenizers parallelism because it can't be used before forking and I didn't bother to figure it out
|
|
|
17 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
18 |
if not os.path.isfile(DATAMODULE_PATH):
|
19 |
print("Generating default datamodule")
|
|
|
106 |
callbacks = [LogImageTexCallback(logger, top_k=10, max_length=100),
|
107 |
LearningRateMonitor(logging_interval="step"),
|
108 |
ModelCheckpoint(save_top_k=10,
|
109 |
+
every_n_train_steps=5,
|
110 |
monitor="val_loss",
|
111 |
mode="min",
|
112 |
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
|
|
134 |
trainer.fit(transformer, datamodule=datamodule)
|
135 |
trainer.test(transformer, datamodule=datamodule)
|
136 |
|
137 |
+
if args.log and len(os.listdir(trainer.checkpoint_callback.dirpath)):
|
138 |
transformer = average_checkpoints(model_type=Transformer, checkpoints_dir=trainer.checkpoint_callback.dirpath)
|
139 |
+
transformer_path = os.path.join(RESOURCES, f"model_{trainer.logger.version}.pt")
|
140 |
transformer.eval()
|
141 |
transformer.freeze()
|
142 |
torch.save(transformer.state_dict(), transformer_path)
|