savtadepth / src /code /training.py
Dean
Successfully configured the dataloader and trained for one epoch. Results are not so good, but it's something. Still the Fastaiv1 looked better qualitatively
13f0309
raw
history blame
No virus
1.58 kB
import torch
import sys
from fastai2.vision.all import *
from torchvision.utils import save_image
class ImageImageDataLoaders(DataLoaders):
"Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"
@classmethod
@delegates(DataLoaders.from_dblock)
def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):
"Create from list of `fnames` in `path`s with `label_func`."
dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
splitter=RandomSplitter(valid_pct, seed=seed),
get_y=label_func,
item_tfms=item_tfms,
batch_tfms=batch_tfms)
res = cls.from_dblock(dblock, fnames, path=path, **kwargs)
return res
def get_y_fn(x):
y = str(x.absolute()).replace('.jpg', '_depth.png')
y = Path(y)
return y
def create_data(data_path):
fnames = get_files(data_path/'train', extensions='.jpg')
data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0, fnames=fnames, label_func=get_y_fn)
return data
if __name__ == "__main__":
if len(sys.argv) < 2:
print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
sys.exit(0)
data = create_data(Path(sys.argv[1]))
learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/')
learner.fine_tune(1)
learner.save('model')