"""Gradio demo that classifies an ant image into one of the labels in ``classes``
using a fine-tuned ViT-IJEPA model loaded from ``vit_ijepa_ant_1.pt``."""
import gradio as gr
import torch
import torchvision
from einops import rearrange
from torchvision.transforms import Compose

from IJEPA_finetune import ViTIJEPA

# Output labels: index i of the model's output vector is reported as classes[i].
classes = ['Acanthostichus', 'Aenictus', 'Amblyopone', 'Attini', 'Bothriomyrmecini',
           'Camponotini', 'Cerapachys', 'Cheliomyrmex', 'Crematogastrini', 'Cylindromyrmex',
           'Dolichoderini', 'Dorylus', 'Eciton', 'Ectatommini', 'Formicini', 'Fulakora',
           'Gesomyrmecini', 'Gigantiopini', 'Heteroponerini', 'Labidus', 'Lasiini',
           'Leptomyrmecini', 'Lioponera', 'Melophorini', 'Myopopone', 'Myrmecia',
           'Myrmelachistini', 'Myrmicini', 'Myrmoteratini', 'Mystrium', 'Neivamyrmex',
           'Nomamyrmex', 'Oecophyllini', 'Ooceraea', 'Paraponera', 'Parasyscia',
           'Plagiolepidini', 'Platythyreini', 'Pogonomyrmecini', 'Ponerini', 'Prionopelta',
           'Probolomyrmecini', 'Proceratiini', 'Pseudomyrmex', 'Solenopsidini', 'Stenammini',
           'Stigmatomma', 'Syscia', 'Tapinomini', 'Tetraponera', 'Zasphinctus']

# Output index -> class name.
idx_to_class = dict(enumerate(classes))
# NOTE: despite its name, this maps index -> class name; kept as an alias so
# existing references to ``class_to_idx`` keep working.
class_to_idx = idx_to_class

# Every input image is resized to 64x64 before being fed to the model.
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# NOTE(review): the positional constructor arguments presumably include the
# image size (64, matching the resize above) and channel count (3) -- confirm
# against IJEPA_finetune.ViTIJEPA.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# Put the model in inference mode; otherwise layers such as dropout or
# normalization (if the architecture has any) keep their training behaviour.
model.eval()


def ant_genus_classification(image):
    """Return the model's confidence for every class on a single image.

    Args:
        image: Array-like image in channels-last (H, W, C) layout, as produced
            by the Gradio ``"image"`` input. Pixel values are passed to the
            model unscaled, as in the original preprocessing.

    Returns:
        dict mapping each class name to its softmax probability (a Python
        float), suitable for ``gr.Label``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image.")

    batch = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
    # Model expects channels-first input: (batch, channels, height, width).
    batch = rearrange(batch, 'b h w c -> b c h w')
    batch = tf(batch)

    with torch.no_grad():
        logits = model(batch)[0]
    probabilities = torch.nn.functional.softmax(logits, dim=0).tolist()
    return {idx_to_class[i]: p for i, p in enumerate(probabilities)}


demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    demo.launch(debug=True)