Spidartist committed
Commit • 88ae77c
Parent(s): c2fe698
Upload 7 files
- IJEPA_finetune.py +85 -0
- app.py +90 -0
- downstream-ant-epoch=82-val_loss=0.07.ckpt +3 -0
- sample-cifar10-epoch=399-ant.ckpt +3 -0
- vit_ijepa_ant_1.pt +3 -0
IJEPA_finetune.py
ADDED
@@ -0,0 +1,85 @@
import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
from torchmetrics.functional import accuracy
from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
from x_transformers import Encoder, Decoder

ON_EPOCH = True
ON_STEP = False
BATCH_SIZE = 64
TARGET_SIZE = (64, 64)
SPLIT_RATE = 0.8
ROOT_DIR_DATA = "/kaggle/input/ant-data-new/data"


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=TARGET_SIZE[0], patch_size=4, in_chans=3, embed_dim=64):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        # number of patches along each spatial dimension
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])

        # convolutional layer that cuts the image into non-overlapping patches
        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.conv(x)
        # flatten the patch grid into a token sequence: (b, e, h, w) -> (b, h*w, e)
        x = rearrange(x, 'b e h w -> b (h w) e')
        return x


class ViTIJEPA(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, num_heads,
                 num_classes, post_emb_norm=False, layer_dropout=0.):
        super().__init__()
        self.layer_dropout = layer_dropout
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        self.student_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
            flash=True
        )

        # pool across the embedding dimension: (b, n, e) -> (b, n, 1)
        self.average_pool = nn.AvgPool1d(embed_dim, stride=1)
        # mlp head maps the pooled per-token values to class logits
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.num_tokens),
            nn.Linear(self.num_tokens, num_classes),
        )

    def forward(self, x):
        x = self.patch_embed(x)
        b, n, e = x.shape
        # add the positional embeddings
        x = x + self.pos_embedding
        # normalize the embeddings
        x = self.post_emb_norm(x)
        # run the encoder over the full token sequence
        x = self.student_encoder(x)

        x = self.average_pool(x)  # average pool as in the paper
        x = x.squeeze(-1)
        x = self.mlp_head(x)  # pass through mlp head
        return x
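A quick shape check helps here, since the head is unusual: AvgPool1d pools across the embedding dimension, so the classifier's input width is the token count (16 × 16 = 256 for 64-pixel images with 4-pixel patches), not embed_dim. A minimal sketch, not part of this commit, using the same constructor arguments as app.py below:

import torch
from IJEPA_finetune import ViTIJEPA

# 64x64 RGB input, 4x4 patches, 64-dim embeddings, 8 layers, 8 heads, 51 classes
model = ViTIJEPA(64, 4, 3, 64, 8, 8, 51)
x = torch.randn(2, 3, 64, 64)  # dummy batch of two images
logits = model(x)
print(logits.shape)  # torch.Size([2, 51])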
app.py
ADDED
@@ -0,0 +1,90 @@
import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision

classes = ['Acanthostichus', 'Aenictus', 'Amblyopone', 'Attini', 'Bothriomyrmecini',
           'Camponotini', 'Cerapachys', 'Cheliomyrmex', 'Crematogastrini',
           'Cylindromyrmex', 'Dolichoderini', 'Dorylus', 'Eciton', 'Ectatommini',
           'Formicini', 'Fulakora', 'Gesomyrmecini', 'Gigantiopini', 'Heteroponerini',
           'Labidus', 'Lasiini', 'Leptomyrmecini', 'Lioponera', 'Melophorini',
           'Myopopone', 'Myrmecia', 'Myrmelachistini', 'Myrmicini', 'Myrmoteratini',
           'Mystrium', 'Neivamyrmex', 'Nomamyrmex', 'Oecophyllini', 'Ooceraea',
           'Paraponera', 'Parasyscia', 'Plagiolepidini', 'Platythyreini',
           'Pogonomyrmecini', 'Ponerini', 'Prionopelta', 'Probolomyrmecini',
           'Proceratiini', 'Pseudomyrmex', 'Solenopsidini', 'Stenammini',
           'Stigmatomma', 'Syscia', 'Tapinomini', 'Tetraponera', 'Zasphinctus']
# note: despite the name, this maps index -> class name
class_to_idx = {idx: cls for idx, cls in enumerate(classes)}

tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
model.eval()  # inference mode


def ant_genus_classification(image):
    image = torch.Tensor(image)
    image = image.unsqueeze(0)
    # gradio delivers H x W x C; the model expects channels first
    image = rearrange(image, 'b h w c -> b c h w')
    image = tf(image)

    print(image.shape)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(image)[0], dim=0)
    # print(prediction.tolist())
    confidences = {class_to_idx[i]: float(prediction[i]) for i in range(len(classes))}
    return confidences
    # prediction = model(image)[0]
    # prediction = prediction.tolist()
    # print(prediction)
    # return {
    #     class_to_idx[i]: prediction[i] for i in range(len(prediction)) if prediction[i] > 0.01
    # }


demo = gr.Interface(fn=ant_genus_classification, inputs="image", outputs=gr.Label(num_top_classes=3))

if __name__ == "__main__":
    demo.launch(debug=True)
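For local testing without the UI, the function can be called directly; Gradio's "image" input hands it an H × W × C numpy array. A minimal sketch, assuming app.py and vit_ijepa_ant_1.pt sit in the working directory:

import numpy as np
from app import ant_genus_classification

# any RGB array works; the Resize transform scales it to 64x64 before inference
fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
confidences = ant_genus_classification(fake_image)
top3 = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3)  # the Label component displays the three highest scores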
downstream-ant-epoch=82-val_loss=0.07.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:481ca7a89125f57b57dc11d4fa111294f670cfb6e8b1b183bc9eb6922fc87d81
size 25619801
sample-cifar10-epoch=399-ant.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1895c1679dfccce8093531f650ad9b7fb888c8bd335e2b7e1d1eab2d6fc87bae
size 33430443
vit_ijepa_ant_1.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d00f67fe6537693dfae4b35841643a54e893e3a21a5b91fd954e0718e1836982
size 5438855
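The three checkpoint entries above are Git LFS pointer files (spec version, sha256 oid, byte size), not the binaries themselves. One way to pull the resolved weights, sketched with huggingface_hub; the repo_id here is an assumption, since the commit view only shows the author name:

from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="Spidartist/ant-genus-space",  # hypothetical Space id
    filename="vit_ijepa_ant_1.pt",
    repo_type="space",
)
print(weights_path)  # local cache path to the downloaded checkpoint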