Spidartist committed on
Commit
88ae77c
1 Parent(s): c2fe698

Upload 7 files

Browse files
IJEPA_finetune.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from torchmetrics.functional import accuracy
10
+ from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
11
+ from x_transformers import Encoder, Decoder
12
+
13
# Input resolution as (height, width); PatchEmbed takes TARGET_SIZE[0] as its
# default img_size.
TARGET_SIZE = (64, 64)

# Training-time settings. None of these are referenced by the code visible in
# this file -- presumably they are consumed by a training script; confirm.
BATCH_SIZE = 64
SPLIT_RATE = 0.8
ROOT_DIR_DATA = "/kaggle/input/ant-data-new/data"

# NOTE(review): these look like on_epoch/on_step flags for Lightning logging,
# but nothing in view uses them -- confirm before relying on that.
ON_EPOCH = True
ON_STEP = False
19
+
20
+
21
class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal the patch size
    projects each non-overlapping patch to an ``embed_dim``-sized vector.

    Args:
        img_size: Image size, either an int (square) or a (height, width) pair.
        patch_size: Patch size, either an int (square) or a (height, width) pair.
        in_chans: Number of channels in the input image.
        embed_dim: Length of the embedding produced for each patch.
    """

    def __init__(self, img_size=TARGET_SIZE[0], patch_size=4, in_chans=3, embed_dim=64):
        super().__init__()
        # Promote scalar sizes to (height, width) pairs.
        img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
        patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

        # Patch grid as (rows, cols); floor division drops any remainder.
        grid_rows = img_size[0] // patch_size[0]
        grid_cols = img_size[1] // patch_size[1]
        self.patch_shape = (grid_rows, grid_cols)

        # kernel_size == stride, so each output position covers exactly one patch.
        self.conv = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed ``x`` and return one row per patch.

        The convolution output (batch, embed, rows, cols) is rearranged to
        (batch, rows * cols, embed).
        """
        return rearrange(self.conv(x), 'b e h w -> b (h w) e')
44
+
45
+
46
class ViTIJEPA(nn.Module):
    """Vision-transformer classifier: patch embedding, encoder, pooled linear head.

    Args:
        img_size: Input image size passed to ``PatchEmbed``.
        patch_size: Patch size passed to ``PatchEmbed``.
        in_chans: Number of input image channels.
        embed_dim: Token embedding width.
        enc_depth: Number of encoder layers.
        num_heads: Number of attention heads per encoder layer.
        num_classes: Number of output classes (width of the returned logits).
        post_emb_norm: If true, apply ``LayerNorm`` after adding positional
            embeddings; otherwise that step is an identity.
        layer_dropout: ``layer_dropout`` value forwarded to the encoder.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, num_heads,
                 num_classes, post_emb_norm=False,
                 layer_dropout=0.):
        super().__init__()
        self.layer_dropout = layer_dropout
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        # One token per patch.
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        # Learned positional embedding, broadcast over the batch axis.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        # NOTE(review): the attribute name presumably mirrors the pretraining
        # checkpoint's encoder so its weights load by key -- confirm.
        self.student_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
            flash=True
        )

        # AvgPool1d pools along the LAST axis. With (batch, tokens, embed)
        # input and kernel == embed_dim this averages each token's features
        # down to a single value, giving (batch, tokens, 1).
        self.average_pool = nn.AvgPool1d(embed_dim, stride=1)
        # Classification head. It is sized by num_tokens (not embed_dim)
        # because of the pooling axis above. Pooling over tokens instead would
        # change these parameter shapes and invalidate existing checkpoints.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.num_tokens),
            nn.Linear(self.num_tokens, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.patch_embed(x)  # (batch, tokens, embed)
        # Add the positional embeddings.
        x = x + self.pos_embedding
        # Optional normalisation (identity unless post_emb_norm was set).
        x = self.post_emb_norm(x)
        x = self.student_encoder(x)

        # Mean over the embedding axis: (batch, tokens, embed) -> (batch, tokens).
        x = self.average_pool(x)
        x = x.squeeze(-1)
        return self.mlp_head(x)
85
+
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from IJEPA_finetune import ViTIJEPA
3
+ import torch
4
+ from einops import rearrange
5
+ from torchvision.transforms import Compose
6
+ import torchvision
7
+
8
# Labels the classifier can predict. Order matters: position i in this list
# labels output i of the model, so do not reorder or insert entries.
classes = [
    'Acanthostichus',
    'Aenictus',
    'Amblyopone',
    'Attini',
    'Bothriomyrmecini',
    'Camponotini',
    'Cerapachys',
    'Cheliomyrmex',
    'Crematogastrini',
    'Cylindromyrmex',
    'Dolichoderini',
    'Dorylus',
    'Eciton',
    'Ectatommini',
    'Formicini',
    'Fulakora',
    'Gesomyrmecini',
    'Gigantiopini',
    'Heteroponerini',
    'Labidus',
    'Lasiini',
    'Leptomyrmecini',
    'Lioponera',
    'Melophorini',
    'Myopopone',
    'Myrmecia',
    'Myrmelachistini',
    'Myrmicini',
    'Myrmoteratini',
    'Mystrium',
    'Neivamyrmex',
    'Nomamyrmex',
    'Oecophyllini',
    'Ooceraea',
    'Paraponera',
    'Parasyscia',
    'Plagiolepidini',
    'Platythyreini',
    'Pogonomyrmecini',
    'Ponerini',
    'Prionopelta',
    'Probolomyrmecini',
    'Proceratiini',
    'Pseudomyrmex',
    'Solenopsidini',
    'Stenammini',
    'Stigmatomma',
    'Syscia',
    'Tapinomini',
    'Tetraponera',
    'Zasphinctus',
]
59
# Index -> class-name lookup used to label model outputs. Despite the name,
# the keys are integer indices and the values are class names.
class_to_idx = {idx: cls for idx, cls in enumerate(classes)}

# Resize every input to the 64x64 resolution the model is constructed with.
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

model = ViTIJEPA(
    img_size=64,
    patch_size=4,
    in_chans=3,
    embed_dim=64,
    enc_depth=8,
    num_heads=8,
    num_classes=len(classes),
)
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# The app only runs inference; eval mode disables training-only behaviour
# such as dropout so predictions are deterministic.
model.eval()
65
+
66
+
67
+ def ant_genus_classification(image):
68
+ image = torch.Tensor(image)
69
+ image = image.unsqueeze(0)
70
+ image = rearrange(image, 'b h w c -> b c h w')
71
+ image = tf(image)
72
+
73
+ print(image.shape)
74
+ with torch.no_grad():
75
+ prediction = torch.nn.functional.softmax(model(image)[0], dim=0)
76
+ # print(prediction.tolist())
77
+ confidences = {class_to_idx[i]: float(prediction[i]) for i in range(len(classes))}
78
+ return confidences
79
+ # prediction = model(image)[0]
80
+ # prediction = prediction.tolist()
81
+ # print(prediction)
82
+ # return {
83
+ # class_to_idx[i]: prediction[i] for i in range(len(prediction)) if prediction[i] > 0.01
84
+ # }
85
+
86
+
87
# Gradio UI: a single image input, with a label output that shows the three
# highest-confidence classes.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)
downstream-ant-epoch=82-val_loss=0.07.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:481ca7a89125f57b57dc11d4fa111294f670cfb6e8b1b183bc9eb6922fc87d81
3
+ size 25619801
sample-cifar10-epoch=399-ant.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1895c1679dfccce8093531f650ad9b7fb888c8bd335e2b7e1d1eab2d6fc87bae
3
+ size 33430443
vit_ijepa_ant_1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d00f67fe6537693dfae4b35841643a54e893e3a21a5b91fd954e0718e1836982
3
+ size 5438855