# Snapshot metadata: revision f7e3261, 5,380 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
from PIL import Image
import matplotlib.pyplot as plt
import os
# Thanks to ( ), a proxy can be essential :) — uncomment the lines below to route HTTP(S) traffic through a local proxy.
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'
# Images that main() scores and plots, left to right, in this order.
IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg',
]

# Multiplier applied to the tanh-bounded outputs of Scorer's first head.
TANH_SCALE = 1
class Scorer(nn.Module):
    """Predict liking, collection, AI-probability and relative-popularity scores.

    A timm backbone created with ``features_only`` returns one feature map per
    stage. Each map is average-pooled over its spatial dimensions and
    normalized with its own LayerNorm. The per-stage vectors are then
    concatenated into one vector. That vector goes through a shared
    Linear/BatchNorm1d/GELU trunk (``mlp``) and then two heads:

    * ``mlp_1``: three tanh-bounded outputs, multiplied by the module-level
      ``TANH_SCALE``. ``forward`` returns them as liking, collection and
      relative popularity.
    * ``mlp_2``: a single logit that ``forward`` turns into a probability with
      a sigmoid (the "AI" score).

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: forwarded to ``timm.create_model``.
        features_only: forwarded to ``timm.create_model``. ``forward`` expects
            the backbone to return a list of (batch, channels, H, W) feature
            maps, so this should stay ``True``.
        embedding_dim: currently unused. It is presumably meant for the
            upload-date embedding that is commented out in ``forward``.
    """

    # Per-stage channel widths of convnextv2_base (the backbone this file
    # configures). Used only when the backbone does not report its own widths
    # via ``feature_info.channels()``.
    _DEFAULT_STAGE_CHANNELS = (128, 256, 512, 1024)

    def __init__(
        self,
        model_name,
        pretrained=False,
        features_only=True,
        embedding_dim=128
    ):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        stage_channels = self._stage_channels(features_only)
        pooled_dim = sum(stage_channels)
        # One LayerNorm per backbone stage, applied to that stage's pooled vector.
        self.layer_norms = nn.ModuleList([nn.LayerNorm(channels) for channels in stage_channels])
        # Shared trunk feeding both heads.
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )
        # Speculative: a BatchNorm in this head, as in BYOL's projector, may help.
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    def _stage_channels(self, features_only):
        """Return the channel count of each backbone stage output, in order."""
        feature_info = getattr(self.model, 'feature_info', None)
        # timm feature-extraction wrappers expose FeatureInfo.channels(). Other
        # model types keep feature_info as a plain list, so fall back to the
        # known convnextv2_base widths for them.
        if features_only and hasattr(feature_info, 'channels'):
            return list(feature_info.channels())
        return list(self._DEFAULT_STAGE_CHANNELS)

    def forward(self, x, upload_date=None, freeze_backbone=False):
        """Score a batch of images.

        Args:
            x: input batch fed to the backbone. Presumably a normalized
                (batch, 3, H, W) image tensor — TODO confirm.
            upload_date: unused; placeholder for the disabled date embedding.
            freeze_backbone: run the backbone under ``torch.no_grad()``. In
                that case no gradient reaches the backbone parameters *or* ``x``.

        Returns:
            Tuple ``(liking, collection, ai_probability, relative_popularity)``.
            Each element is a tensor of shape ``(batch,)``.

        Note:
            In training mode the BatchNorm1d layers need a batch size > 1.
        """
        if freeze_backbone:
            with torch.no_grad():
                stage_maps = self.model(x)
        else:
            stage_maps = self.model(x)
        # Global-average-pool each stage map over its spatial dims:
        # (batch, C, H, W) -> (batch, C).
        pooled = [F.adaptive_avg_pool2d(fmap, 1).flatten(1) for fmap in stage_maps]
        # Normalize each stage's pooled vector with its own LayerNorm.
        pooled = [self.layer_norms[i](vec) for i, vec in enumerate(pooled)]
        # Upload-date embedding is disabled; upload_date is ignored for now.
        out = torch.cat(pooled, dim=-1)
        out = self.mlp(out)
        rl_out = self.mlp_1(out) * TANH_SCALE    # (batch, 3)
        ai_out = self.mlp_2(out).squeeze(-1)     # (batch,)
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_out), rl_out[:, 2]
# timm backbone used by Scorer. It must match the checkpoint at WEIGHT_PATH,
# which is loaded with a strict load_state_dict.
BACKBONE = 'convnextv2_base.fcmae'
# Input images are resized to RESOLUTION x RESOLUTION before scoring.
RESOLUTION = 640
# When SHOW_GRAD is set, main() overlays GRAD_SCALE times the input gradient
# of the objective selected by the flags below onto each image.
SHOW_GRAD = False
GRAD_SCALE = 50
# Objective terms for the overlay: each MORE_* flag subtracts the matching
# prediction from the objective, while LESS_AI adds the AI probability.
MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True
WEIGHT_PATH = './scorer.pt'
# Fall back to the CPU when CUDA is unavailable so the script still runs.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Original (misspelled) name, kept as an alias because existing code uses it.
DECIVE = DEVICE
def _build_transform():
    """Return the preprocessing pipeline applied to every input image.

    Resizes to RESOLUTION x RESOLUTION, converts to a float tensor, and
    normalizes with the standard ImageNet per-channel mean/std.
    """
    return transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def _load_model():
    """Build a Scorer for BACKBONE, load WEIGHT_PATH, and return it in eval mode on DECIVE."""
    model = Scorer(BACKBONE)
    # The freshly built model lives on the CPU, so load the checkpoint there
    # too. Without map_location, a GPU-saved checkpoint fails to load on a
    # CPU-only machine.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(WEIGHT_PATH, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model.to(DECIVE)


def _overlay_gradients(img, gradients):
    """Return PIL image *img* with GRAD_SCALE * *gradients* added.

    *gradients* is the (3, H', W') input gradient. It is resized to the
    image's own size before being added.
    """
    gradients = transforms.Resize((img.height, img.width))(gradients)
    blended = transforms.ToTensor()(img) + gradients.cpu() * GRAD_SCALE
    # ToPILImage multiplies floats by 255 and casts to uint8 without clipping,
    # so out-of-range values come out garbled; clamp to [0, 1] first.
    return transforms.ToPILImage()(blended.clamp(0.0, 1.0))


def main():
    """Score every image in IMG_FILE_LIST and plot them side by side.

    Each subplot is titled with the predicted liking, collection, AI
    probability (as a percentage) and relative popularity. When SHOW_GRAD is
    set, each image is overlaid with the input gradient of the objective built
    from the MORE_* / LESS_AI flags.

    Raises:
        ValueError: SHOW_GRAD is set but no objective flag is enabled.
    """
    if SHOW_GRAD and not any((MORE_LIKE, MORE_COLLECTION, LESS_AI, MORE_RELATIVE_POP)):
        raise ValueError(
            'SHOW_GRAD requires at least one of MORE_LIKE, MORE_COLLECTION, '
            'LESS_AI or MORE_RELATIVE_POP to be enabled'
        )
    model = _load_model()
    transform = _build_transform()
    # Show all images in one horizontal row, with the predictions as titles.
    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        with Image.open(img_file, 'r') as raw:
            img = raw.convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DECIVE)
        # Input gradients (and the autograd graph) are only needed for the overlay.
        transformed_img.requires_grad_(SHOW_GRAD)
        with torch.set_grad_enabled(SHOW_GRAD):
            # torch.tensor([1]) is a placeholder for the unused upload_date.
            liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)
        if SHOW_GRAD:
            # The guard above ensures at least one term is added, so the
            # objective is a one-element tensor rather than the int 0.
            backwardee = 0
            if MORE_LIKE:
                backwardee -= liking_pred
            if MORE_COLLECTION:
                backwardee -= collection_pred
            if LESS_AI:
                backwardee += ai_pred
            if MORE_RELATIVE_POP:
                backwardee -= relative_pop
            model.zero_grad()
            # Find which parts of the image the objective is most sensitive to.
            backwardee.backward()
            gradients = transformed_img.grad.squeeze(0).detach()
            img = _overlay_gradients(img, gradients)
        ax.imshow(img)
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\nCollection: {collection_pred.item():.3f}\nAI: {ai_pred.item() * 100:.3f}%\nPopularity: {relative_pop.item():.3f}')
    plt.show()
# Run the scoring/plotting script only when executed directly, not on import.
if __name__ == '__main__':
    main()