import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
from PIL import Image
import matplotlib.pyplot as plt
import os

# Thanks to ( ), a proxy can be essential :)
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'

IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg'
]

TANH_SCALE = 1


class Scorer(nn.Module):
    def __init__(
            self,
            model_name,
            pretrained=False,
            features_only=True,
            embedding_dim=128  # kept for the (commented-out) upload-date embedding
    ):
        super(Scorer, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        # Channel counts of the four feature stages returned by the backbone
        pooled_dim = 128 + 256 + 512 + 1024
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(128),
            nn.LayerNorm(256),
            nn.LayerNorm(512),
            nn.LayerNorm(1024)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )
        # Maybe a BYOL-style "accidental" BatchNorm could help?
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    def forward(self, x, upload_date=None, freeze_backbone=False):
        if freeze_backbone:
            with torch.no_grad():
                out_features = self.model(x)
        else:
            out_features = self.model(x)
        # out_features: List [
        #     torch.Size([1, 128, x, x])
        #     torch.Size([1, 256, x, x])
        #     torch.Size([1, 512, x, x])
        #     torch.Size([1, 1024, x, x])
        # ]
        # Global-average-pool each stage's feature map over its spatial dimensions
        pooled_features = [F.adaptive_avg_pool2d(feat, 1).squeeze(-1).squeeze(-1) for feat in out_features]
        # Normalize the pooled features
        pooled_features = [self.layer_norms[i](feat) for i, feat in enumerate(pooled_features)]
        # Embed the upload date
        # date_embedding_features = self.embedding(upload_date)
        # Concatenate the pooled features
        out = torch.cat(pooled_features, dim=-1)
        # Concatenate the date embedding features
        # out = torch.cat([out, date_embedding_features], dim=-1)
        out = self.mlp(out)
        rl_out = self.mlp_1(out) * TANH_SCALE
        ai_out = self.mlp_2(out).squeeze(-1)
        # Returns: liking score, collection score, AI probability, relative popularity
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_out), rl_out[:, 2]


BACKBONE = 'convnextv2_base.fcmae'
RESOLUTION = 640
SHOW_GRAD = False
GRAD_SCALE = 50
MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True
WEIGHT_PATH = './scorer.pt'
DEVICE = 'cuda'


def main():
    model = Scorer(BACKBONE)
    transform = transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    model.load_state_dict(torch.load(WEIGHT_PATH, map_location=DEVICE))
    model.eval()
    model.to(DEVICE)

    # Show all the images in pyplot horizontally, with the predicted values above each image
    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        img = Image.open(img_file).convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DEVICE)
        transformed_img.requires_grad = True
        # upload_date is currently unused by the model
        liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)

        # Build the objective to backpropagate; at least one flag below must be
        # enabled for backward() to work when SHOW_GRAD is True
        backwardee = 0
        if MORE_LIKE:
            backwardee -= liking_pred
        if MORE_COLLECTION:
            backwardee -= collection_pred
        if LESS_AI:
            backwardee += ai_pred
        if MORE_RELATIVE_POP:
            backwardee -= relative_pop

        if SHOW_GRAD:
            model.zero_grad()
            # Figure out which part of the image matters most to the objective
            backwardee.backward()
            # Get the gradients of the input image
            gradients = transformed_img.grad
            # Squeeze the batch dimension
            gradients = gradients.squeeze(0).detach()
            # Resize the gradients to the same size as the image
            gradients = transforms.Resize((img.height, img.width))(gradients)
            # Overlay the gradients on the image; clamp so ToPILImage does not overflow
            img = transforms.ToTensor()(img)
            img = (img + gradients.cpu() * GRAD_SCALE).clamp(0, 1)
            img = transforms.ToPILImage()(img.cpu())

        ax.imshow(img)
        del img
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\n'
            f'Collection: {collection_pred.item():.3f}\n'
            f'AI: {ai_pred.item() * 100:.3f}%\n'
            f'Popularity: {relative_pop.item():.3f}'
        )
    plt.show()


if __name__ == '__main__':
    main()