# cloud-classifier / train.py
# Author: Morsecode (commit 3602056, "Added files")
from fastcore.all import *
from fastai.vision.all import *
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Root folder of training images: one sub-folder per class, whose folder name
# is used as that image's label (see `parent_label` below).
path = Path('data')

# Set to True to also write confusion_matrix.png and top_losses.png after
# training, for debugging the data/model.
SAVE_DEBUG_PLOTS = False


def build_dataloaders(data_path=path, *, bs=32, size=250, valid_pct=0.2, seed=42):
    """Return train/validation DataLoaders for the images under *data_path*.

    Every image file found under *data_path* is an item and its label is the
    name of its parent directory. A *valid_pct* fraction of the items, chosen
    reproducibly from *seed*, is held out for validation. Items are resized to
    *size* and batches get ``aug_transforms(mult=1.1)`` augmentation.
    """
    # These are the only transforms applied. Note that `DataBlock.new()`
    # *replaces* previously set item/batch transforms rather than extending
    # them, so transforms cannot be stacked by chaining `.new()` calls.
    block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(valid_pct=valid_pct, seed=seed),
        get_y=parent_label,
        item_tfms=Resize(size),
        batch_tfms=aug_transforms(mult=1.1),
    )
    return block.dataloaders(data_path, bs=bs)


def train(dls, *, arch=resnet18, epochs=5, fname='model.pkl'):
    """Fine-tune a pretrained *arch* on *dls*, export it as *fname*, return it.

    ``error_rate`` is tracked as the metric during training.
    """
    learn = vision_learner(dls, arch, metrics=error_rate)
    learn.fine_tune(epochs)  # trains in place; its return value is not needed
    learn.export(fname)
    return learn


def save_diagnostics(learn, *, n_top_losses=12):
    """Save a confusion matrix and a top-losses grid for *learn* as PNG files."""
    interp = ClassificationInterpretation.from_learner(learn)

    interp.plot_confusion_matrix()
    plt.savefig('confusion_matrix.png')
    plt.close('all')  # free the figure before drawing the next one

    interp.plot_top_losses(n_top_losses, nrows=4)
    plt.savefig('top_losses.png')
    plt.close('all')


def main(save_plots=SAVE_DEBUG_PLOTS):
    """Train the classifier on the images in `path` and export it."""
    learn = train(build_dataloaders(path))
    if save_plots:
        save_diagnostics(learn)


# Guard the entry point: DataLoader worker processes may be started with the
# 'spawn' method (the default on Windows and macOS), which re-imports this
# module; without the guard each worker would try to run training again.
if __name__ == '__main__':
    main()