import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
from torchmetrics.functional import accuracy
from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
from x_transformers import Encoder, Decoder

ON_EPOCH = True
ON_STEP = False
BATCH_SIZE = 64
TARGET_SIZE = (64, 64)
SPLIT_RATE = 0.8
ROOT_DIR_DATA = "/kaggle/input/ant-data-new/data"


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Splits an image into non-overlapping patches with a strided convolution
    (kernel size == stride == patch size) and returns one embedding vector
    per patch.

    Args:
        img_size: Image side length, or an ``(height, width)`` pair.
        patch_size: Patch side length, or an ``(height, width)`` pair.
        in_chans: Number of input image channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size=TARGET_SIZE[0], patch_size=4, in_chans=3, embed_dim=64):
        super().__init__()
        if isinstance(img_size, int):
            img_size = img_size, img_size
        if isinstance(patch_size, int):
            patch_size = patch_size, patch_size

        # Number of patches along (height, width); any remainder is dropped.
        self.patch_shape = (
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        )

        # Convolution whose stride equals its kernel: one output position
        # per patch, with `embed_dim` output channels.
        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and flatten the patch grid into a token sequence.

        Returns:
            Tensor laid out as ``(batch, h * w, embed_dim)``.
        """
        x = self.conv(x)
        # Flatten the spatial grid and move the embedding axis last.
        return rearrange(x, 'b e h w -> b (h w) e')


class ViTIJEPA(nn.Module):
    """Classifier built from a patch embedding, a transformer encoder and a head.

    The encoder output is averaged over the embedding axis, leaving one
    scalar per token; the head then maps those ``num_tokens`` scalars to
    ``num_classes`` logits.

    Args:
        img_size: Image side length (or pair) passed to ``PatchEmbed``.
        patch_size: Patch side length (or pair) passed to ``PatchEmbed``.
        in_chans: Number of input image channels.
        embed_dim: Token embedding size.
        enc_depth: Number of encoder layers.
        num_heads: Number of attention heads.
        num_classes: Number of output logits.
        post_emb_norm: If true, apply ``LayerNorm`` after adding the
            positional embedding; otherwise that step is an identity.
        layer_dropout: Layer dropout rate forwarded to the encoder.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth,
                 num_heads, num_classes, post_emb_norm=False, layer_dropout=0.):
        super().__init__()
        self.layer_dropout = layer_dropout

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]

        # Learned positional embedding, one vector per patch token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()

        self.student_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
            # x_transformers forwards only `attn_`-prefixed kwargs to its
            # attention layers; the documented switch is `attn_flash`.
            # A bare `flash=True` never enables flash attention.
            attn_flash=True,
        )

        # Kernel spans the whole embedding axis, so each token collapses
        # to a single averaged value.
        self.average_pool = nn.AvgPool1d(embed_dim, stride=1)

        # mlp head: operates across tokens, not across embedding features.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.num_tokens),
            nn.Linear(self.num_tokens, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Image batch accepted by ``PatchEmbed``; it must yield exactly
                ``num_tokens`` patches so the positional embedding lines up.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.patch_embed(x)

        # Add the positional embeddings (broadcast over the batch axis).
        x = x + self.pos_embedding

        # Optional normalisation of the embeddings.
        x = self.post_emb_norm(x)

        x = self.student_encoder(x)

        # Average over the embedding axis -> (batch, num_tokens, 1),
        # then drop the trailing singleton axis.
        x = self.average_pool(x)
        x = x.squeeze(-1)

        # Pass through mlp head.
        return self.mlp_head(x)