Add model script and pre-trained weights
Browse files
app.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# model_part5_deploy.py
|
2 |
+
|
3 |
+
# 01. Import Packages {{{
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
sys.path.append('C:/python/my/pytorch/kaggle/pawpularity')
|
8 |
+
import shutil
|
9 |
+
|
10 |
+
import timm
|
11 |
+
from timm.data import resolve_data_config
|
12 |
+
from timm.data.transforms_factory import create_transform
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.utils.data import DataLoader, Dataset
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
import numpy as np
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
from PIL import Image
|
22 |
+
from glob import glob
|
23 |
+
from tqdm import tqdm
|
24 |
+
import cv2
|
25 |
+
import gc # garbage collector, recycles memory usage
|
26 |
+
import albumentations as A # library for image augmentation
|
27 |
+
import gradio as gr
|
28 |
+
#}}}
|
29 |
+
# Gradio wrap {{{
|
30 |
+
# Texts rendered by the Gradio page: the heading, the HTML intro shown above
# the widget, and the footer note. They are user-facing and kept verbatim.
title = "🐱用AI给小可爱们打个分🐶"
description = (
    "\n"
    "<center>\n"
    "这是一个用AI来判断你小可爱照片有多受欢迎的小工具。作者希望能帮助动物救助组织给流浪的猫猫狗狗们更快找到一个温暖的家。\n"
    '<img src="https://i.pinimg.com/564x/6f/9b/24/6f9b24e85d5bfb8acff726b5457bbd5c.jpg" width=200px>\n'
    "</center>\n"
)
article = "此模型使用大约10000张已标注的宠物图片进行训练.目前只支持喵星人和汪星人输出结果。如果想要了解更多,请联系作者.微信:Roy_Ma_US."
|
38 |
+
# 02. Model constants {{{
# Use the GPU when one is present, otherwise fall back to the CPU so that a
# CPU-only host can still serve predictions instead of failing on startup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Config:
    """Settings shared by every experiment of the ensemble."""
    # NOTE(review): these are absolute Windows paths; confirm they exist on
    # the machine that serves the app, otherwise no weights will be found.
    model_base_dir = 'D:/ML/pytorch/pretrained/'
    model_file_ext = '/*.pth'
    base_dir = "D:/ML/datasets/pawpularity"
    data_dir = base_dir
    # Raw string: the backslashes are literal path separators, not escapes.
    output_dir = r'D:\ML\pytorch\pretrained\output'
    img_test_dir = os.path.join(data_dir, "test")
    im_size = 224
    batch_size = 1
    num_workers = 0
# }}}


# 03. Define Dataset {{{
class PetDataset(Dataset):
    """Dataset that loads images from disk as float CHW tensors in [0, 1].

    Args:
        image_filepaths: one path (``str`` / ``os.PathLike``) or a sequence
            of paths. A single path is treated as a one-image dataset.
        targets: sequence with one target value per image.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=ndarray)["image"]``.
    """

    def __init__(self, image_filepaths, targets, transform=None):
        # A bare string is itself a sequence (of characters); wrap it so
        # that len() and indexing refer to images, not characters.
        if isinstance(image_filepaths, (str, os.PathLike)):
            image_filepaths = [image_filepaths]
        self.image_filepaths = list(image_filepaths)
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.image_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.image_filepaths[idx]
        with open(image_filepath, 'rb') as f:
            # Decode while the file is still open; PIL reads lazily.
            image = np.array(Image.open(f).convert('RGB'))  # (H, W, C)

        if self.transform is not None:
            image = self.transform(image=image)["image"]

        image = image / 255  # scale pixel values to [0, 1]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)  # HWC -> CHW

        image = torch.tensor(image, dtype=torch.float)
        target = torch.tensor(self.targets[idx], dtype=torch.float)
        return image, target


def inference_fixed_transforms(mode=0, dim=224):
    """Build the deterministic resize/crop pipeline used at inference time.

    Args:
        mode: 0 resizes the short edge to ``dim`` and center-crops
            ``dim`` x ``dim``; 1 resizes the short edge to ``dim + 16``,
            center-crops ``dim`` x ``dim`` and flips horizontally.
        dim: side length of the square crop.

    Raises:
        ValueError: if ``mode`` is neither 0 nor 1.
    """
    if mode not in (0, 1):
        # Returning None here would silently skip resizing altogether.
        raise ValueError(f'unsupported TTA mode: {mode!r} (expected 0 or 1)')

    if mode == 0:  # keep original aspect, colors and orientation
        return A.Compose([
            A.SmallestMaxSize(max_size=dim, p=1.0),
            A.CenterCrop(height=dim, width=dim, p=1.0),
        ], p=1.0)

    # mode == 1: slightly enlarge, crop, then mirror
    return A.Compose([
        A.SmallestMaxSize(max_size=dim + 16, p=1.0),
        A.CenterCrop(height=dim, width=dim, p=1.0),
        A.HorizontalFlip(p=1.0),
    ], p=1.0)
# }}}


# 04. Define model class {{{
class PetNet(nn.Module):
    """timm backbone whose classifier emits ``out_features`` raw logits."""

    def __init__(self, model_name, out_features=1, inp_channels=3, pretrained=False):
        super().__init__()
        # The attribute must stay named ``model``: checkpoint keys depend on it.
        self.model = timm.create_model(model_name,
                                       pretrained=pretrained,
                                       in_chans=inp_channels,
                                       num_classes=out_features)

    def forward(self, image):
        return self.model(image)  # (batch_size, out_features)


class PetNet_exp77(nn.Module):
    """timm backbone with its stock classifier, followed by dropout and a linear head."""

    # Width of the backbone output that feeds the extra head.
    # NOTE(review): assumes the backbone keeps timm's default 1000-way
    # classifier - confirm for any new model_name.
    NC = 1000

    def __init__(self, model_name, out_features=1, inp_channels=3, pretrained=False):
        super().__init__()
        # Attribute names (model / dropout / head) must match the checkpoints.
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=inp_channels)
        self.dropout = nn.Dropout(0.05)
        self.head = nn.Linear(self.NC, out_features)

    def forward(self, image):
        output = self.model(image)
        output = self.dropout(output)
        return self.head(output)


def tta_fn(filepaths, model, im_size, ttas=(0, 1)):
    """Predict with test-time augmentation and average over the TTA modes.

    Args:
        filepaths: one image path or a sequence of image paths.
        model: network returning one logit per image.
        im_size: side length passed to ``inference_fixed_transforms``.
        ttas: iterable of TTA modes to run and average.

    Returns:
        ``np.ndarray`` with one score per image: ``sigmoid(logit) * 100``
        averaged over ``ttas``.

    Raises:
        ValueError: if ``ttas`` is empty.
    """
    ttas = tuple(ttas)
    if not ttas:
        raise ValueError('ttas must contain at least one TTA mode')
    if isinstance(filepaths, (str, os.PathLike)):
        filepaths = [filepaths]
    filepaths = list(filepaths)

    print('Image Size:', im_size)
    model.eval()
    tta_preds = []
    for tta_mode in ttas:
        print(f'tta mode:{tta_mode}')
        test_dataset = PetDataset(image_filepaths=filepaths,
                                  targets=np.zeros(len(filepaths)),  # dummy targets
                                  transform=inference_fixed_transforms(tta_mode, dim=im_size))
        test_loader = DataLoader(test_dataset,
                                 batch_size=Config.batch_size,
                                 shuffle=False,
                                 num_workers=Config.num_workers,
                                 # pinned host memory only helps host->GPU copies
                                 pin_memory=(device.type == 'cuda'))
        tta_pred = []
        for images, _ in test_loader:
            images = images.to(device, non_blocking=True).float()
            with torch.no_grad():
                output = model(images)
            # logits -> 0..100 scores, flattened into a plain list
            pred = (torch.sigmoid(output).detach().cpu().numpy() * 100).ravel().tolist()
            tta_pred.extend(pred)
        tta_preds.append(np.array(tta_pred))

        del test_loader, test_dataset

    fold_preds = np.mean(np.stack(tta_preds), axis=0)

    gc.collect()
    torch.cuda.empty_cache()
    return fold_preds
# }}}


# 05.-08. Experiment definitions {{{
class Config_exp53(Config):
    model_dir = 'exp53'
    model_name = "swin_large_patch4_window7_224"


class Config_exp55(Config):
    model_dir = 'exp55'
    model_name = "beit_large_patch16_224"


class Config_exp66(Config):
    model_dir = 'exp66'
    model_name = "swin_large_patch4_window12_384_in22k"
    im_size = 384


class Config_exp77(Config):
    model_dir = 'exp77'
    model_name = "beit_large_patch16_224"


# (experiment config, model class, blend weight, TTA modes)
_ENSEMBLE = (
    (Config_exp53, PetNet, 3, (1,)),
    (Config_exp55, PetNet, 4, (0,)),
    (Config_exp66, PetNet, 3, (0,)),
    (Config_exp77, PetNet_exp77, 4, (0,)),
)


def _predict_with_experiment(filepath, cfg, model_cls, ttas):
    """Average the predictions of every checkpoint found for one experiment.

    Raises:
        FileNotFoundError: if no ``*.pth`` file matches the experiment folder.
    """
    pattern = Config.model_base_dir + cfg.model_dir + Config.model_file_ext
    modelfiles = glob(pattern)
    if not modelfiles:
        # Without this check the mean of an empty list becomes NaN and the
        # request fails later with an unrelated IndexError.
        raise FileNotFoundError(f'no model weights found matching {pattern!r}')

    per_checkpoint = []
    for model_path in modelfiles:
        print(f'inferencing with: {model_path}')
        model = model_cls(model_name=cfg.model_name, out_features=1,
                          inp_channels=3, pretrained=False)
        # map_location lets weights saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device).float()
        model.eval()
        per_checkpoint.append(tta_fn(filepath, model, cfg.im_size, ttas))
        del model  # release the weights before the next checkpoint is loaded
    return np.mean(np.array(per_checkpoint), axis=0)
# }}}


# 09. Final predicted scores {{{
def score(input_img):
    """Score one pet photo with the four-experiment ensemble.

    Args:
        input_img: path of the image file handed over by Gradio.

    Returns:
        A message containing the weighted ensemble score plus a fixed
        offset of 20, rounded to two decimals.

    Raises:
        FileNotFoundError: if an experiment folder holds no weight files.
    """
    weighted = []
    total_weight = 0
    for cfg, model_cls, weight, ttas in _ENSEMBLE:
        preds = _predict_with_experiment(input_img, cfg, model_cls, ttas)
        print(f'>>>{cfg.model_dir}: ', preds)
        weighted.append(weight * preds)
        total_weight += weight

    final_predictions = sum(weighted) / total_weight  # weighted mean
    boosted = (final_predictions + 20).round(2)
    boosted = boosted[0]  # single image -> single score

    if boosted > 80:
        return str(boosted) + "分! 美照!🥰"
    return "大数据说" + str(boosted) + "分. 多拍几张试试?🤗"
# }}}
|
263 |
+
# Wire the scoring function into a Gradio page: one image upload (handed to
# score() as a file path) in, one text box out.
image_input = gr.inputs.Image(label="给哪位小可爱打分?😉", type='filepath')
text_output = gr.outputs.Textbox(label="得分是...✨", type='str')
iface = gr.Interface(
    fn=score,
    inputs=image_input,
    outputs=text_output,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
)
iface.launch()
|
270 |
+
#}}}
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
|
286 |
+
|
exp53/swin_large_patch4_window7_224_fold0_half.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:acbbfc7f8b47bb1c9cb7e012f26f5ca90a299b288afcd4fe5ea0021014010937
|
3 |
+
size 391161211
|
exp55/beit_large_patch16_224_fold0_half.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d57c1faf222c48fc10e52ffa1a74fc50d16bbf677919dc486dada2ac41d3b96
|
3 |
+
size 614423238
|
exp66/swin_large_patch4_window12_384_in22k_fold0_half.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3908c5a171d1d3c0840ac11ae66e09767422ea821131a626dda0248e87f8563b
|
3 |
+
size 399341755
|
exp77/beit_large_patch16_224_fold0_half.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d95198f26b8bdecdd42bd2e17dd0a5104567846c610c1af3e3ab42912dc6af5
|
3 |
+
size 616473816
|