# -*- coding: utf-8 -*- """20221206.ipynb Automatically generated by Colaboratory. Original file is located at https://colab.research.google.com/drive/1QJCfDvr9ofVasBT4JZbaZHr44QYCMuB0 ###1.1 安裝套件(若在colab訓練每次都需要執行) """ !pip install fastbook -q """###1.2 讀取套件""" from fastbook import * from fastai.vision.widgets import * divice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") divice import fastai; print('fastai version:',fastai.__version__) print('torch version:',torch.__version__) """###1.3 準備資料集""" from google.colab import drive drive.mount('/content/drive') path = Path('/content/drive/MyDrive/dataset/mycat2') keywords = {'Lion': 'Lion','Tiger':'Tiger','Snow leopard':'Snow leopard'} array = keywords.items() if not path.exists(): !mkdir -p /content/drive/MyDrive/dataset/mycat2 for key,value in array: print(key,value) dest = (path/key) dest.mkdir(exist_ok=True) urls = search_images_ddg(f'{value}',max_images=300) download_images (dest,urls=urls) """###1.4 清洗資料""" fns = get_image_files(path) failed = verify_images(fns) failed.map(Path.unlink) #unlink broken images """###2.1. 設定訓練資料路徑""" path = Path('/content/drive/MyDrive/dataset/mycat2') path.ls() #建立模型銓重儲存路徑 myPath = '/content/drive/MyDrive/dataset/models' !mkdir -p $myPath """###2.2. 資料讀取框架""" dataset = DataBlock( blocks=(ImageBlock,CategoryBlock), get_items = get_image_files, splitter = RandomSplitter(valid_pct=0.2,seed=42), item_tfms = Resize(224), get_y = parent_label ) #利用框架正式讀取資料 dls = dataset.dataloaders(path,bs=16,num_workers=16) #讀取結果 print(dls.c,dls.vocab,len(dls.train_ds),len(dls.valid_ds)) print('訓練資料') dls.show_batch(max_n=5, nrows=1,unique=True) dls.show_batch(max_n=5,nrows=1) """###3.1. 選擇模型架構以及對應的預訓練權重 ###Note: metrics是模型訓練人員觀察的指標, 可設定多個 """ learn = vision_learner(dls, resnet34, metrics=[accuracy, error_rate], pretrained=True) learn.fit_one_cycle(3, 1e-3) """###3.2 儲存第一次訓練好的權重""" myModel=myPath+'/resnet34_stage-1.pkl' learn.export(myModel) """###3.3 解凍權重再次訓練""" learn.unfreeze() lr_min,lr_steep = learn.lr_find(suggest_funcs=(minimum, steep)) print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}") learn.fit_one_cycle(6, lr_max= 2.75e-04) """###3.4 儲存新的權重""" myModel=myPath+'/resnet34_stage-2.pkl' learn.export(myModel) """###4.1. 結果檢核(Confusion Matrix)""" # Contains interpretation methods for classification models interp = ClassificationInterpretation.from_learner(learn) # Plot the confusion matrix interp.plot_confusion_matrix() interp.plot_top_losses(5, nrows=1) """###4.2 ROC Curve and AUC""" preds,y, loss = learn.get_preds(with_loss=True) # get accuracy acc = accuracy(preds, y) print('The accuracy is {0} %.'.format(acc)) from sklearn.metrics import roc_curve, auc # probs from log preds probs = np.exp(preds[:,1]) # Compute ROC curve fpr, tpr, thresholds = roc_curve(y, probs, pos_label=1) # Compute ROC area roc_auc = auc(fpr, tpr) print('ROC area is {0}'.format(roc_auc)) plt.figure() plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc) plt.plot([0, 1], [0, 1], color='navy', linestyle='--') plt.xlim([-0.01, 1.01]) plt.ylim([-0.01, 1.01]) plt.axis('square') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver operating characteristic') plt.legend(loc="lower right") """###5. 預測""" ## 模型位置 from fastbook import * from fastai.vision.widgets import * device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device """###5.1 讀取先前訓練好的權重""" myPath='/content/drive/MyDrive/dataset/models' myModel=myPath+'/resnet34_stage-2.pkl' learn = load_learner(myModel) """###5.2 讀取檔案並送入模型預測""" ## 執行預測 - method I fnames_Lion = get_image_files('/content/drive/MyDrive/dataset/mycat2/Lion') fnames_Tiger = get_image_files('/content/drive/MyDrive/dataset/mycat2/Tiger') fnames_Snow_leopard = get_image_files('/content/drive/MyDrive/dataset/mycat2/Snow leopard') fnames_Lion pred_class,pred_idx,outputs = learn.predict(fnames_Lion[3]) print("Actual: Lion, Predicted = {}".format(pred_class)) pred_class,pred_idx,outputs = learn.predict(fnames_Tiger[7]) print("Actual: Tiger, Predicted = {}".format(pred_class)) """###6. Visualization with Grad-CAM""" class Hook(): def __init__(self, m): self.hook = m.register_forward_hook(self.hook_func) def hook_func(self, m, i, o): self.stored = o.detach().clone() def __enter__(self, *args): return self def __exit__(self, *args): self.hook.remove() class HookBwd(): def __init__(self, m): self.hook = m.register_backward_hook(self.hook_func) def hook_func(self, m, gi, go): self.stored = go[0].detach().clone() def __enter__(self, *args): return self def __exit__(self, *args): self.hook.remove() """###讀取要繪製的影像""" fnames_Australian = get_image_files('/content/drive/MyDrive/dataset/mycat2/Tiger') test_dl = learn.dls.test_dl(fnames_Tiger, with_label=True) print(len(test_dl.get_idxs())) # pred_probas, _, pred_classes = learn.get_preds(dl=test_dl, with_decoded=True) test_dl.show_batch() from torchvision.transforms.functional import to_tensor fn = test_dl.items[0] x_dec = PILImage.create(fn); #Resize: 224 填充黑邊 rsz = Resize(224, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros) x_dec = rsz(x_dec) x = to_tensor(x_dec) x.unsqueeze_(0) x.shape,type(x) """###繪製最後一層的feature map的Grad-CAM""" cls = 1 with HookBwd(learn.model[0]) as hookg: with Hook(learn.model[0]) as hook: # output = learn.model.eval()(x.cuda()) output = learn.model.eval()(x.cpu()) act = hook.stored output[0,cls].backward() grad = hookg.stored w = grad[0].mean(dim=[1,2], keepdim=True) cam_map = (w * act[0]).sum(0) _,ax = plt.subplots() x_dec.show(ctx=ax) ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0), interpolation='bilinear', cmap='magma'); """###看其他feature map的grad-cam""" with HookBwd(learn.model[0][-2]) as hookg: with Hook(learn.model[0][-2]) as hook: # output = learn.model.eval()(x.cuda()) output = learn.model.eval()(x.cpu()) act = hook.stored output[0,cls].backward() grad = hookg.stored w = grad[0].mean(dim=[1,2], keepdim=True) cam_map = (w * act[0]).sum(0) _,ax = plt.subplots() x_dec.show(ctx=ax) ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0), interpolation='bilinear', cmap='magma');