SuShih commited on
Commit
78dfe98
1 Parent(s): 28c8c19

Upload 20221206.py

Browse files
Files changed (1) hide show
  1. 20221206.py +241 -0
20221206.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """20221206.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1QJCfDvr9ofVasBT4JZbaZHr44QYCMuB0
8
+
9
+ ###1.1 安裝套件(若在colab訓練每次都需要執行)
10
+ """
11
+
12
+ !pip install fastbook -q
13
+
14
+ """###1.2 讀取套件"""
15
+
16
+ from fastbook import *
17
+ from fastai.vision.widgets import *
18
+ divice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+ divice
20
+
21
+ import fastai;
22
+ print('fastai version:',fastai.__version__)
23
+ print('torch version:',torch.__version__)
24
+
25
+ """###1.3 準備資料集"""
26
+
27
+ from google.colab import drive
28
+ drive.mount('/content/drive')
29
+
30
+ path = Path('/content/drive/MyDrive/dataset/mycat2')
31
+ keywords = {'Lion': 'Lion','Tiger':'Tiger','Snow leopard':'Snow leopard'}
32
+ array = keywords.items()
33
+ if not path.exists():
34
+ !mkdir -p /content/drive/MyDrive/dataset/mycat2
35
+ for key,value in array:
36
+ print(key,value)
37
+ dest = (path/key)
38
+ dest.mkdir(exist_ok=True)
39
+ urls = search_images_ddg(f'{value}',max_images=300)
40
+ download_images (dest,urls=urls)
41
+
42
+ """###1.4 清洗資料"""
43
+
44
+ fns = get_image_files(path)
45
+ failed = verify_images(fns)
46
+ failed.map(Path.unlink) #unlink broken images
47
+
48
+ """###2.1. 設定訓練資料路徑"""
49
+
50
+ path = Path('/content/drive/MyDrive/dataset/mycat2')
51
+ path.ls()
52
+
53
+ #建立模型銓重儲存路徑
54
+ myPath = '/content/drive/MyDrive/dataset/models'
55
+ !mkdir -p $myPath
56
+
57
+ """###2.2. 資料讀取框架"""
58
+
59
+ dataset = DataBlock(
60
+ blocks=(ImageBlock,CategoryBlock),
61
+ get_items = get_image_files,
62
+ splitter = RandomSplitter(valid_pct=0.2,seed=42),
63
+ item_tfms = Resize(224),
64
+ get_y = parent_label
65
+ )
66
+
67
+ #利用框架正式讀取資料
68
+ dls = dataset.dataloaders(path,bs=16,num_workers=16)
69
+
70
+ #讀取結果
71
+ print(dls.c,dls.vocab,len(dls.train_ds),len(dls.valid_ds))
72
+
73
+ print('訓練資料')
74
+ dls.show_batch(max_n=5, nrows=1,unique=True)
75
+
76
+ dls.show_batch(max_n=5,nrows=1)
77
+
78
+ """###3.1. 選擇模型架構以及對應的預訓練權重
79
+
80
+ ###Note: metrics是模型訓練人員觀察的指標, 可設定多個
81
+ """
82
+
83
+ learn = vision_learner(dls, resnet34, metrics=[accuracy, error_rate], pretrained=True)
84
+ learn.fit_one_cycle(3, 1e-3)
85
+
86
+ """###3.2 儲存第一次訓練好的權重"""
87
+
88
+ myModel=myPath+'/resnet34_stage-1.pkl'
89
+ learn.export(myModel)
90
+
91
+ """###3.3 解凍權重再次訓練"""
92
+
93
+ learn.unfreeze()
94
+ lr_min,lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))
95
+ print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
96
+
97
+ learn.fit_one_cycle(6, lr_max= 2.75e-04)
98
+
99
+ """###3.4 儲存新的權重"""
100
+
101
+ myModel=myPath+'/resnet34_stage-2.pkl'
102
+ learn.export(myModel)
103
+
104
+ """###4.1. 結果檢核(Confusion Matrix)"""
105
+
106
+ # Contains interpretation methods for classification models
107
+ interp = ClassificationInterpretation.from_learner(learn)
108
+ # Plot the confusion matrix
109
+ interp.plot_confusion_matrix()
110
+
111
+ interp.plot_top_losses(5, nrows=1)
112
+
113
+ """###4.2 ROC Curve and AUC"""
114
+
115
+ preds,y, loss = learn.get_preds(with_loss=True)
116
+ # get accuracy
117
+ acc = accuracy(preds, y)
118
+ print('The accuracy is {0} %.'.format(acc))
119
+
120
+ from sklearn.metrics import roc_curve, auc
121
+ # probs from log preds
122
+ probs = np.exp(preds[:,1])
123
+ # Compute ROC curve
124
+ fpr, tpr, thresholds = roc_curve(y, probs, pos_label=1)
125
+
126
+ # Compute ROC area
127
+ roc_auc = auc(fpr, tpr)
128
+ print('ROC area is {0}'.format(roc_auc))
129
+
130
+ plt.figure()
131
+ plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc)
132
+ plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
133
+ plt.xlim([-0.01, 1.01])
134
+ plt.ylim([-0.01, 1.01])
135
+ plt.axis('square')
136
+ plt.xlabel('False Positive Rate')
137
+ plt.ylabel('True Positive Rate')
138
+ plt.title('Receiver operating characteristic')
139
+ plt.legend(loc="lower right")
140
+
141
+ """###5. 預測"""
142
+
143
+ ## 模型位置
144
+ from fastbook import *
145
+ from fastai.vision.widgets import *
146
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
147
+ device
148
+
149
+ """###5.1 讀取先前訓練好的權重"""
150
+
151
+ myPath='/content/drive/MyDrive/dataset/models'
152
+ myModel=myPath+'/resnet34_stage-2.pkl'
153
+ learn = load_learner(myModel)
154
+
155
+ """###5.2 讀取檔案並送入模型預測"""
156
+
157
+ ## 執行預測 - method I
158
+ fnames_Lion = get_image_files('/content/drive/MyDrive/dataset/mycat2/Lion')
159
+ fnames_Tiger = get_image_files('/content/drive/MyDrive/dataset/mycat2/Tiger')
160
+ fnames_Snow_leopard = get_image_files('/content/drive/MyDrive/dataset/mycat2/Snow leopard')
161
+
162
+ fnames_Lion
163
+
164
+ pred_class,pred_idx,outputs = learn.predict(fnames_Lion[3])
165
+ print("Actual: Lion, Predicted = {}".format(pred_class))
166
+
167
+ pred_class,pred_idx,outputs = learn.predict(fnames_Tiger[7])
168
+ print("Actual: Tiger, Predicted = {}".format(pred_class))
169
+
170
+ """###6. Visualization with Grad-CAM"""
171
+
172
+ class Hook():
173
+ def __init__(self, m):
174
+ self.hook = m.register_forward_hook(self.hook_func)
175
+ def hook_func(self, m, i, o): self.stored = o.detach().clone()
176
+ def __enter__(self, *args): return self
177
+ def __exit__(self, *args): self.hook.remove()
178
+
179
+ class HookBwd():
180
+ def __init__(self, m):
181
+ self.hook = m.register_backward_hook(self.hook_func)
182
+ def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
183
+ def __enter__(self, *args): return self
184
+ def __exit__(self, *args): self.hook.remove()
185
+
186
+ """###讀取要繪製的影像"""
187
+
188
+ fnames_Australian = get_image_files('/content/drive/MyDrive/dataset/mycat2/Tiger')
189
+ test_dl = learn.dls.test_dl(fnames_Tiger, with_label=True)
190
+ print(len(test_dl.get_idxs()))
191
+ # pred_probas, _, pred_classes = learn.get_preds(dl=test_dl, with_decoded=True)
192
+
193
+ test_dl.show_batch()
194
+
195
+ from torchvision.transforms.functional import to_tensor
196
+ fn = test_dl.items[0]
197
+ x_dec = PILImage.create(fn);
198
+
199
+ #Resize: 224 填充黑邊
200
+ rsz = Resize(224, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros)
201
+ x_dec = rsz(x_dec)
202
+ x = to_tensor(x_dec)
203
+ x.unsqueeze_(0)
204
+ x.shape,type(x)
205
+
206
+ """###繪製最後一層的feature map的Grad-CAM"""
207
+
208
+ cls = 1
209
+ with HookBwd(learn.model[0]) as hookg:
210
+ with Hook(learn.model[0]) as hook:
211
+ # output = learn.model.eval()(x.cuda())
212
+ output = learn.model.eval()(x.cpu())
213
+ act = hook.stored
214
+ output[0,cls].backward()
215
+ grad = hookg.stored
216
+
217
+ w = grad[0].mean(dim=[1,2], keepdim=True)
218
+ cam_map = (w * act[0]).sum(0)
219
+
220
+ _,ax = plt.subplots()
221
+ x_dec.show(ctx=ax)
222
+ ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
223
+ interpolation='bilinear', cmap='magma');
224
+
225
+ """###看其他feature map的grad-cam"""
226
+
227
+ with HookBwd(learn.model[0][-2]) as hookg:
228
+ with Hook(learn.model[0][-2]) as hook:
229
+ # output = learn.model.eval()(x.cuda())
230
+ output = learn.model.eval()(x.cpu())
231
+ act = hook.stored
232
+ output[0,cls].backward()
233
+ grad = hookg.stored
234
+
235
+ w = grad[0].mean(dim=[1,2], keepdim=True)
236
+ cam_map = (w * act[0]).sum(0)
237
+
238
+ _,ax = plt.subplots()
239
+ x_dec.show(ctx=ax)
240
+ ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
241
+ interpolation='bilinear', cmap='magma');