Spaces:
Paused
Paused
File size: 1,645 Bytes
3c0c5aa 068408a eeb74de 3c0c5aa 13f0309 3c0c5aa eeb74de 13f0309 3c0c5aa 79fd7d0 3c0c5aa eeb74de 9c03436 79fd7d0 9c03436 79fd7d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
import sys
from fastai.vision.all import *
from torchvision.utils import save_image
class ImageImageDataLoaders(DataLoaders):
    """`DataLoaders` subclass with factory methods for image-to-image tasks.

    Inputs are loaded as `PILImage`; targets are loaded as single-channel
    `PILImageBW` images.
    """

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):
        """Build loaders from `fnames` under `path`, locating each target with `label_func`.

        `valid_pct` of the items are held out by a `RandomSplitter` seeded with
        `seed`; `item_tfms`/`batch_tfms` are handed to the `DataBlock`, and any
        remaining keyword arguments are forwarded to `DataLoaders.from_dblock`.
        """
        source_block = ImageBlock(cls=PILImage)
        target_block = ImageBlock(cls=PILImageBW)
        splitter = RandomSplitter(valid_pct, seed=seed)
        block = DataBlock(
            blocks=(source_block, target_block),
            splitter=splitter,
            get_y=label_func,
            item_tfms=item_tfms,
            batch_tfms=batch_tfms,
        )
        return cls.from_dblock(block, fnames, path=path, **kwargs)
def get_y_fn(x):
    """Return the absolute path of the depth map paired with the image at *x*.

    The target sits next to the image and shares its stem, with the ``.jpg``
    extension replaced by ``_depth.png``:
    ``train/0001.jpg`` -> ``<abs>/train/0001_depth.png``.

    Args:
        x: Path (or str) of a ``.jpg`` image.

    Returns:
        Absolute ``Path`` of the matching ``*_depth.png`` file.

    Raises:
        ValueError: if *x* does not have a ``.jpg`` extension; returning the
            path unchanged would silently make the image its own label.
    """
    image_path = Path(x).absolute()
    # Compare the suffix case-insensitively: fastai's get_files matches
    # extensions case-insensitively, so ".JPG" files reach this function too.
    if image_path.suffix.lower() != ".jpg":
        raise ValueError(f"expected a .jpg image, got {image_path}")
    # Rewrite only the file name, never a ".jpg" that happens to appear in a
    # parent directory's name.
    return image_path.with_name(f"{image_path.stem}_depth.png")
def create_data(data_path, bs=4, seed=42, num_workers=0):
    """Build image -> depth-map DataLoaders from ``<data_path>/train``.

    Every ``.jpg`` found (recursively) under ``data_path/'train'`` is an input
    image; its target path is produced by ``get_y_fn``.

    Args:
        data_path: Dataset root (Path or str) containing a ``train`` folder.
        bs: Batch size passed to the DataLoaders.
        seed: Seed for the random train/validation split.
        num_workers: DataLoader worker count (0 loads batches in the main
            process).

    Returns:
        An ``ImageImageDataLoaders`` instance.

    Raises:
        FileNotFoundError: if no ``.jpg`` files are found under the train folder.
    """
    train_path = Path(data_path) / 'train'
    fnames = get_files(train_path, extensions='.jpg')
    # Fail early with a clear message instead of an obscure error later in
    # DataBlock/DataLoaders construction or training.
    if not fnames:
        raise FileNotFoundError(f"no .jpg files found under {train_path}")
    return ImageImageDataLoaders.from_label_func(
        train_path,
        seed=seed,
        bs=bs,
        num_workers=num_workers,
        fnames=fnames,
        label_func=get_y_fn,
    )
if __name__ == "__main__":
    # Usage: python <this script> <data_path>
    # Trains a ResNet-34 U-Net that maps RGB images to depth maps for one
    # fine-tuning epoch, then saves the weights.
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <data_path>", file=sys.stderr)
        # A missing argument is a usage error: exit non-zero (2, as argparse
        # does) so shells and CI do not treat the run as successful.
        sys.exit(2)
    data = create_data(Path(sys.argv[1]))
    # Targets are single-channel depth maps (ImageBlock(cls=PILImageBW) in
    # ImageImageDataLoaders), so the U-Net must output exactly one channel.
    # With n_out=3 the flattened prediction is 3x the length of the flattened
    # target, and MSELossFlat / rmse fail on the mismatch.
    learner = unet_learner(data, resnet34, metrics=rmse, wd=1e-2, n_out=1,
                           loss_func=MSELossFlat(), path='src/test/')
    print("Training model...")
    learner.fine_tune(1)
    print("Saving model...")
    # fastai's Learner.save writes to <path>/<model_dir>/model.pth.
    learner.save('model')
|