|
import gradio as gr |
|
from IJEPA_finetune import ViTIJEPA |
|
import torch |
|
from einops import rearrange |
|
from torchvision.transforms import Compose |
|
import torchvision |
|
|
|
class CMYKToRGB(object):
    """Tensor transform that converts 4-channel CMYK images to 3-channel RGB.

    Tensors whose channel dimension (dim -3) is not 4 are returned unchanged,
    so RGB inputs pass straight through.

    NOTE(review): the converted channels are scaled by 255 while ToTensor()
    produces values in [0, 1], so CMYK inputs end up in [0, 255] whereas RGB
    inputs stay in [0, 1] — confirm this matches the training pipeline.
    """

    def __call__(self, img):
        # Only torch tensors are supported; reject anything else early.
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input image should be a torch.Tensor")

        # Not CMYK (e.g. already RGB) — pass through untouched.
        if img.shape[-3] != 4:
            return img

        # Split the four CMYK planes along the channel dimension.
        cyan, magenta, yellow, black = img.unbind(-3)

        # Standard CMYK -> RGB conversion.
        red = 255 * (1 - cyan) * (1 - black)
        green = 255 * (1 - magenta) * (1 - black)
        blue = 255 * (1 - yellow) * (1 - black)

        # Reassemble into a 3-channel tensor on the same channel axis.
        return torch.stack((red, green, blue), dim=-3)
|
|
|
# All labels the classifier can predict, in training order: the model's
# output index i corresponds to classes[i], so this list must not be
# reordered independently of the checkpoint.
classes = ['Acanthostichus',

'Aenictus',

'Amblyopone',

'Attini',

'Bothriomyrmecini',

'Camponotini',

'Cerapachys',

'Cheliomyrmex',

'Crematogastrini',

'Cylindromyrmex',

'Dolichoderini',

'Dorylus',

'Eciton',

'Ectatommini',

'Formicini',

'Fulakora',

'Gesomyrmecini',

'Gigantiopini',

'Heteroponerini',

'Labidus',

'Lasiini',

'Leptomyrmecini',

'Lioponera',

'Melophorini',

'Myopopone',

'Myrmecia',

'Myrmelachistini',

'Myrmicini',

'Myrmoteratini',

'Mystrium',

'Neivamyrmex',

'Nomamyrmex',

'Oecophyllini',

'Ooceraea',

'Paraponera',

'Parasyscia',

'Plagiolepidini',

'Platythyreini',

'Pogonomyrmecini',

'Ponerini',

'Prionopelta',

'Probolomyrmecini',

'Proceratiini',

'Pseudomyrmex',

'Solenopsidini',

'Stenammini',

'Stigmatomma',

'Syscia',

'Tapinomini',

'Tetraponera',

'Zasphinctus']

# Maps model output index -> class name.
# NOTE(review): despite the name "class_to_idx", this is an index-to-class
# mapping (and is used as such when building the confidence dict below).
class_to_idx = {idx: cls for idx, cls in enumerate(classes)}
|
|
|
# Preprocessing pipeline applied to every input image:
#   1. convert PIL image / ndarray to a float tensor in [0, 1],
#   2. resize to the 64x64 resolution the model expects,
#   3. convert any 4-channel CMYK tensor to 3-channel RGB.
train_transforms = Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((64, 64), antialias=True),
        CMYKToRGB(),
    ]
)
|
|
|
|
|
# Instantiate the network with the same hyper-parameters used at training time.
# NOTE(review): the positional arguments (64, 4, 3, 64, 8, 8, num_classes) are
# presumably image size, patch size, channels, embed dim, depth, heads — TODO
# confirm against ViTIJEPA's constructor signature.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))

# Load the fine-tuned checkpoint; map_location=CPU keeps this runnable on
# machines without a GPU.
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
|
|
|
|
|
def ant_genus_classification(image):
    """Classify an ant image and return per-class confidence scores.

    Args:
        image: Image as delivered by Gradio's "image" input (an H x W x C
            numpy array or PIL image) — anything ``train_transforms`` accepts.

    Returns:
        dict mapping each class name to its softmax probability (float),
        the format expected by ``gr.Label``.
    """
    # Preprocess and add the batch dimension the model expects: (1, C, 64, 64).
    batch = train_transforms(image).unsqueeze(0)

    # Inference mode: eval() disables dropout/batch-norm updates, no_grad()
    # skips autograd bookkeeping. (Debug prints removed from this path.)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # class_to_idx maps output index -> class name (see its definition above).
    return {class_to_idx[i]: float(probs[0][i]) for i in range(len(classes))}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Wire the classifier into a Gradio UI: one image input, a label widget
# showing the ten highest-confidence classes.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=10),
)

if __name__ == "__main__":
    # debug=True surfaces tracebacks in the browser during development.
    demo.launch(debug=True)
|
|