vardaan123 commited on
Commit
3dba732
1 Parent(s): c3e2aa9

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dataset_subtree.csv filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: COSMO
3
- emoji: 😻
4
  colorFrom: green
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.14.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: COSMO
3
+ emoji: 🦀
4
  colorFrom: green
5
+ colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.29.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/model.cpython-38.pyc ADDED
Binary file (3.57 kB). View file
 
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ from model import DistMult
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ import json
10
+ from tqdm import tqdm
11
+
12
+ # Default image tensor normalization
13
+ _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
14
+ _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
15
+
16
+ def generate_target_list(data, entity2id):
17
+ sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
18
+ sub = list(sub['t'])
19
+ categories = []
20
+ for item in tqdm(sub):
21
+ if entity2id[str(int(float(item)))] not in categories:
22
+ categories.append(entity2id[str(int(float(item)))])
23
+ # print('categories = {}'.format(categories))
24
+ # print("No. of target categories = {}".format(len(categories)))
25
+ return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
26
+
27
+ # Load necessary data and initialize the model
28
+ entity2id = json.load(open('entity2id_subtree.json', 'r'))
29
+ id2entity = {v: k for k, v in entity2id.items()}
30
+ datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
31
+ num_ent_id = len(entity2id)
32
+ target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
33
+ overall_id_to_name = json.load(open('overall_id_to_name.json'))
34
+
35
+ # Initialize your model here
36
+ model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
37
+ model.eval()
38
+
39
+ ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
40
+ model.load_state_dict(ckpt['model'], strict=False)
41
+ print('ckpt loaded...')
42
+
43
+ # Define your evaluation function
44
+ def evaluate(img):
45
+ transform_steps = transforms.Compose([
46
+ transforms.ToPILImage(),
47
+ transforms.Resize((448, 448)),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
50
+ ])
51
+ h = transform_steps(img)
52
+ r = torch.tensor([3])
53
+
54
+ # Assuming `move_to` is a function to move tensors to the desired device
55
+ h = h.unsqueeze(0)
56
+ r = r.unsqueeze(0)
57
+
58
+ outputs = F.softmax(model.forward_ce(h, r, triple_type=('image', 'id')), dim=-1)
59
+
60
+ # print('outputs = {}'.format(outputs.size()))
61
+
62
+ predictions = torch.topk(outputs, k=5, dim=-1).indices.squeeze(0).tolist()
63
+
64
+ # print('predictions', predictions)
65
+
66
+ result = {}
67
+ for i in predictions:
68
+ pred_label = target_list[i].item()
69
+ label = overall_id_to_name[str(id2entity[pred_label])]
70
+ prob = outputs[0, i].item()
71
+ result[label] = prob
72
+
73
+ # y_pred = outputs.argmax(-1).cpu()
74
+ # pred_label = target_list[y_pred].item()
75
+ # species_label = overall_id_to_name[str(id2entity[pred_label])]
76
+
77
+ # print('pred_label', pred_label)
78
+ # print('species_label', species_label)
79
+
80
+ # return species_label
81
+ return result
82
+
83
+ # Gradio interface
84
+ species_model = gr.Interface(
85
+ evaluate,
86
+ gr.inputs.Image(shape=(200, 200)),
87
+ outputs="label",
88
+ title='Camera Trap Species Classification demo',
89
+ # description='Species Classification',
90
+ # article='Species Classification'
91
+ )
92
+ species_model.launch(server_port=8977,share=True, debug=True)
dataset_subtree.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e89acb42f04c5593c492cf836ccf6b897d22e76c52b3a262c8e462813fb82cda
3
+ size 43352089
entity2id_subtree.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"93302": 0, "805080": 1, "304358": 2, "332573": 3, "5246131": 4, "691846": 5, "641038": 6, "117569": 7, "147604": 8, "125642": 9, "947318": 10, "801601": 11, "278114": 12, "114656": 13, "114654": 14, "458402": 15, "4940726": 16, "229562": 17, "229560": 18, "244265": 19, "229558": 20, "683263": 21, "847764": 22, "495017": 23, "273244": 24, "796671": 25, "796672": 26, "495018": 27, "495016": 28, "847766": 29, "490533": 30, "490538": 31, "273230": 32, "273227": 33, "5334778": 34, "392222": 35, "392220": 36, "644242": 37, "644258": 38, "864604": 39, "634572": 40, "864610": 41, "747873": 42, "864593": 43, "7067181": 44, "839752": 45, "671055": 46, "831080": 47, "1068783": 48, "922511": 49, "816256": 50, "320824": 51, "254744": 52, "254745": 53, "889045": 54, "271598": 55, "717794": 56, "782350": 57, "782347": 58, "3610624": 59, "3611156": 60, "844553": 61, "185338": 62, "23048": 63, "23039": 64, "666969": 65, "666961": 66, "976847": 67, "976856": 68, "1068778": 69, "237403": 70, "845966": 71, "671049": 72, "392236": 73, "764826": 74, "407000": 75, "44557": 76, "220323": 77, "173067": 78, "276723": 79, "220325": 80, "67361": 81, "220326": 82, "410922": 83, "848923": 84, "848914": 85, "592588": 86, "438471": 87, "438474": 88, "296191": 89, "44559": 90, "384218": 91, "630990": 92, "649553": 93, "866983": 94, "421036": 95, "970404": 96, "394011": 97, "474585": 98, "3609124": 99, "319614": 100, "524854": 101, "173836": 102, "765432": 103, "201068": 104, "970408": 105, "173811": 106, "675197": 107, "675198": 108, "913935": 109, "702152": 110, "386195": 111, "842867": 112, "386191": 113, "770311": 114, "312031": 115, "417957": 116, "417950": 117, "386194": 118, "842868": 119, "741061": 120, "989398": 121, "512437": 122, "842860": 123, "115460": 124, "115449": 125, "268324": 126, "837394": 127, "268346": 128, "203191": 129, "386004": 130, "571323": 131, "392223": 132, "622916": 133, "7655791": 134, "7655792": 135, "510764": 136, "510761": 137, "510762": 138, "986971": 139, "403912": 140, "768685": 141, "768687": 142, "768674": 143, "460505": 144, "534970": 145, "844149": 146, "844145": 147, "534996": 148, "194503": 149, "194523": 150, "194507": 151, "1030872": 152, "1030860": 153, "3611950": 154, "92562": 155, "410156": 156, "410145": 157, "768677": 158, "122647": 159, "19014": 160, "19015": 161, "122641": 162, "798021": 163, "540244": 164, "70819": 165, "346071": 166, "122649": 167, "644255": 168, "70831": 169, "561121": 170, "70827": 171, "70832": 172, "1066581": 173, "490099": 174, "385449": 175, "989809": 176, "989807": 177, "910691": 178, "768678": 179, "768679": 180, "70835": 181, "1016642": 182, "346068": 183, "513794": 184, "513789": 185, "591989": 186, "40168": 187, "1036727": 188, "702522": 189, "513800": 190, "122644": 191, "122645": 192, "591984": 193, "591988": 194, "98208": 195, "591990": 196, "591987": 197, "380144": 198, "436155": 199, "510773": 200, "510775": 201, "510767": 202, "510752": 203, "916745": 204, "730004": 205, "637442": 206, "1037242": 207, "1037247": 208, "906307": 209, "730008": 210, "995191": 211, "995183": 212, "1036752": 213, "1036755": 214, "730021": 215, "730013": 216, "644252": 217, "644249": 218, "644247": 219, "644245": 220, "44565": 221, "827263": 222, "297458": 223, "297460": 224, "679701": 225, "445986": 226, "231614": 227, "1023230": 228, "348043": 229, "67323": 230, "381139": 231, "381140": 232, "348045": 233, "213517": 234, "770319": 235, "837603": 236, "313163": 237, "821952": 238, "821973": 239, "821959": 240, "821953": 241, "372706": 242, "666235": 243, "621176": 244, "247341": 245, "264179": 246, "685113": 247, "3612582": 248, "821960": 249, "872571": 250, "348029": 251, "348040": 252, "914060": 253, "348030": 254, "736280": 255, "348031": 256, "827259": 257, "397138": 258, "397140": 259, "397157": 260, "397160": 261, "563161": 262, "383900": 263, "383901": 264, "397144": 265, "5681": 266, "211399": 267, "159587": 268, "350016": 269, "194343": 270, "5685": 271, "194345": 272, "194340": 273, "5686": 274, "397135": 275, "397136": 276, "252751": 277, "194342": 278, "194349": 279, "42311": 280, "159578": 281, "159576": 282, "42306": 283, "3613295": 284, "563159": 285, "626916": 286, "570215": 287, "280108": 288, "1033548": 289, "1033549": 290, "86169": 291, "86170": 292, "86161": 293, "86162": 294, "42307": 295, "563165": 296, "563163": 297, "774314": 298, "507553": 299, "752746": 300, "626917": 301, "763018": 302, "882766": 303, "86186": 304, "660452": 305, "563154": 306, "42314": 307, "42322": 308, "42324": 309, "563151": 310, "626920": 311, "752758": 312, "752759": 313, "541948": 314, "1070066": 315, "541951": 316, "94003": 317, "520756": 318, "615442": 319, "1068209": 320, "1068227": 321, "1087514": 322, "1034223": 323, "6146951": 324, "9419": 325, "746703": 326, "561107": 327, "561109": 328, "561113": 329, "561114": 330, "561100": 331, "561103": 332, "561106": 333, "561087": 334, "226176": 335, "541924": 336, "541933": 337, "541936": 338, "16033": 339, "277697": 340, "16069": 341, "608046": 342, "393366": 343, "170433": 344, "762047": 345, "919176": 346, "362785": 347, "639642": 348, "329823": 349, "35881": 350, "35888": 351, "4945781": 352, "4945815": 353, "4945816": 354, "4945872": 355, "139516": 356, "4945873": 357, "4945874": 358, "713776": 359, "713772": 360, "872963": 361, "4947372": 362, "335588": 363, "90215": 364, "90223": 365, "664350": 366, "664351": 367, "81461": 368, "241846": 369, "363030": 370, "938413": 371, "931109": 372, "150851": 373, "664463": 374, "244142": 375, "1032057": 376, "1032049": 377, "83286": 378, "604964": 379, "449653": 380, "664480": 381, "539139": 382, "539141": 383, "843074": 384, "772741": 385, "5839486": 386, "241841": 387, "765193": 388, "7068148": 389, "860117": 390, "693339": 391, "837585": 392, "684043": 393, "684045": 394, "684040": 395, "109893": 396, "109881": 397, "157741": 398, "109892": 399, "109882": 400, "979429": 401, "132829": 402, "728070": 403, "51353": 404, "102704": 405, "110936": 406, "521341": 407, "521339": 408, "624441": 409, "53692": 410, "53708": 411, "781250": 412, "446481": 413, "446490": 414, "136462": 415, "446477": 416, "3596058": 417, "204731": 418, "1080967": 419, "352754": 420, "352755": 421, "969837": 422, "609781": 423, "307211": 424, "3596764": 425, "938409": 426, "489432": 427, "989084": 428, "989081": 429, "4947835": 430, "1036185": 431, "786440": 432, "584448": 433, "266054": 434, "313124": 435, "261310": 436, "261316": 437, "427706": 438, "967304": 439, "966318": 440, "5025": 441, "5021": 442, "5030": 443, "3600024": 444, "521835": 445, "521834": 446, "966314": 447, "521837": 448, "414340": 449, "381374": 450, "906602": 451, "1041547": 452, "131990": 453, "774534": 454, "3598135": 455, "96286": 456, "568571": 457, "96367": 458, "176458": 459, "28338": 460, "695334": 461, "645461": 462, "7068500": 463, "3599375": 464, "1076202": 465, "451623": 466, "259942": 467, "1051167": 468, "907909": 469, "106895": 470, "635217": 471, "187411": 472, "320098": 473, "3598028": 474, "81443": 475, "292467": 476, "292469": 477, "402450": 478, "402466": 479, "857847": 480, "857849": 481, "292466": 482, "647692": 483, "8032375": 484, "8032203": 485, "8032276": 486, "8032351": 487, "8032318": 488, "8032251": 489, "8032381": 490, "8032285": 491, "8032224": 492, "8032345": 493, "8032358": 494, "8032284": 495, "8032368": 496, "8032286": 497, "8032289": 498, "8032377": 499, "8032372": 500, "8032362": 501, "8032369": 502, "8032295": 503, "8032363": 504, "8032294": 505, "8032384": 506, "8032383": 507, "8032326": 508, "8032325": 509, "8032234": 510}
model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch import Tensor
5
+ from typing import Tuple
6
+
7
+ from torchvision.models import resnet18, resnet50
8
+ from torchvision.models import ResNet18_Weights, ResNet50_Weights
9
+
10
+ class DistMult(nn.Module):
11
+ def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
12
+ super(DistMult, self).__init__()
13
+ self.num_ent_uid = num_ent_uid
14
+
15
+ self.num_relations = 4
16
+
17
+ self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
18
+ self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
19
+
20
+ self.location_embedding = MLP(2, 512, 3)
21
+
22
+ self.time_embedding = MLP(1, 512, 3)
23
+
24
+ self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
25
+ self.image_embedding.fc = nn.Linear(2048, 512)
26
+
27
+ self.target_list = target_list
28
+
29
+ if all_locs is not None:
30
+ self.all_locs = all_locs.to(device)
31
+ if all_timestamps is not None:
32
+ self.all_timestamps = all_timestamps.to(device)
33
+
34
+ self.device = device
35
+
36
+ self.init()
37
+
38
+ def init(self):
39
+ nn.init.xavier_uniform_(self.ent_embedding.weight.data)
40
+ nn.init.xavier_uniform_(self.rel_embedding.weight.data)
41
+ nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
42
+
43
+ def forward_ce(self, h, r, triple_type=None):
44
+ emb_h = self.batch_embedding_concat_h(h) # [batch, hid]
45
+
46
+ emb_r = self.rel_embedding(r.squeeze(-1)) # [batch, hid]
47
+
48
+ emb_hr = emb_h * emb_r # [batch, hid]
49
+
50
+ if triple_type == ('image', 'id'):
51
+ score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T) # [batch, n_ent]
52
+ elif triple_type == ('id', 'id'):
53
+ score = torch.mm(emb_hr, self.ent_embedding.weight.T) # [batch, n_ent]
54
+ elif triple_type == ('image', 'location'):
55
+ loc_emb = self.location_embedding(self.all_locs) # computed for each batch
56
+ score = torch.mm(emb_hr, loc_emb.T)
57
+ elif triple_type == ('image', 'time'):
58
+ time_emb = self.time_embedding(self.all_timestamps)
59
+ score = torch.mm(emb_hr, time_emb.T)
60
+ else:
61
+ raise NotImplementedError
62
+
63
+ return score
64
+
65
+ def batch_embedding_concat_h(self, e1):
66
+ e1_embedded = None
67
+
68
+ if len(e1.size())==1 or e1.size(1) == 1: # uid
69
+ # print('ent_embedding = {}'.format(self.ent_embedding.weight.size()))
70
+ e1_embedded = self.ent_embedding(e1.squeeze(-1))
71
+ elif e1.size(1) == 15: # time
72
+ e1_embedded = self.time_embedding(e1)
73
+ elif e1.size(1) == 2: # GPS
74
+ e1_embedded = self.location_embedding(e1)
75
+ elif e1.size(1) == 3: # Image
76
+ e1_embedded = self.image_embedding(e1)
77
+
78
+ return e1_embedded
79
+
80
+
81
+ class MLP(nn.Module):
82
+ def __init__(self,
83
+ input_dim,
84
+ output_dim,
85
+ num_layers=3,
86
+ p_dropout=0.0,
87
+ bias=True):
88
+
89
+ super().__init__()
90
+
91
+ self.input_dim = input_dim
92
+ self.output_dim = output_dim
93
+
94
+ self.p_dropout = p_dropout
95
+ step_size = (input_dim - output_dim) // num_layers
96
+ hidden_dims = [output_dim + (i * step_size)
97
+ for i in reversed(range(num_layers))]
98
+
99
+ mlp = list()
100
+ layer_indim = input_dim
101
+ for hidden_dim in hidden_dims:
102
+ mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
103
+ nn.Dropout(p=self.p_dropout, inplace=True),
104
+ nn.PReLU()])
105
+
106
+ layer_indim = hidden_dim
107
+
108
+ self.mlp = nn.Sequential(*mlp)
109
+
110
+ # initialize weights
111
+ self.init()
112
+
113
+ def forward(self, x):
114
+ return self.mlp(x)
115
+
116
+ def init(self):
117
+ for param in self.parameters():
118
+ nn.init.uniform_(param)
overall_id_to_name.json ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "8032375": "motorcycle",
3
+ "8032203": "empty",
4
+ "8032276": "pardofelis temminckii",
5
+ "8032351": "agouti paca",
6
+ "8032318": "cercopithecus lhoesti",
7
+ "8032251": "equus quagga",
8
+ "8032381": "ave desconocida",
9
+ "8032285": "unknown bird",
10
+ "8032224": "mazama gouazoubira",
11
+ "8032345": "francolinus africanus",
12
+ "8032358": "mazama pandora",
13
+ "8032284": "canis familiaris",
14
+ "8032368": "lophura sp",
15
+ "8032286": "unknown bat",
16
+ "8032289": "geotrygon sp",
17
+ "8032377": "puma yagoroundi",
18
+ "8032372": "myiophoneus caeruleus",
19
+ "8032362": "arctonyx hoevenii",
20
+ "8032369": "myiophoneus glaucinus",
21
+ "8032295": "brotogeris sp",
22
+ "8032363": "tragulus sp",
23
+ "8032294": "phaetornis sp",
24
+ "8032384": "mazama temama",
25
+ "8032383": "unknown dove",
26
+ "8032326": "andropadus virens",
27
+ "8032325": "andropadus latirostris",
28
+ "8032234": "herpestes sanguineus",
29
+ "906307": "tayassu pecari",
30
+ "848914": "dasyprocta punctata",
31
+ "296191": "cuniculus paca",
32
+ "42307": "puma concolor",
33
+ "1034223": "tapirus terrestris",
34
+ "1037242": "pecari tajacu",
35
+ "1030860": "mazama americana",
36
+ "752746": "leopardus pardalis",
37
+ "664480": "geotrygon montana",
38
+ "348031": "nasua nasua",
39
+ "796672": "dasypus novemcinctus",
40
+ "381140": "eira barbara",
41
+ "919176": "didelphis marsupialis",
42
+ "914060": "procyon cancrivorus",
43
+ "42322": "panthera onca",
44
+ "490538": "myrmecophaga tridactyla",
45
+ "402466": "tinamus major",
46
+ "634572": "sylvilagus brasiliensis",
47
+ "86162": "puma yagouaroundi",
48
+ "507553": "leopardus wiedii",
49
+ "170433": "philander opossum",
50
+ "19015": "capra aegagrus",
51
+ "490099": "bos taurus",
52
+ "70819": "ovis aries",
53
+ "247341": "canis lupus",
54
+ "747873": "lepus saxatilis",
55
+ "115449": "papio anubis",
56
+ "194343": "genetta genetta",
57
+ "561121": "tragelaphus scriptus",
58
+ "541936": "loxodonta africana",
59
+ "922511": "cricetomys gambianus",
60
+ "513789": "raphicerus campestris",
61
+ "383901": "hyaena hyaena",
62
+ "768679": "aepyceros melampus",
63
+ "397157": "crocuta crocuta",
64
+ "1033549": "caracal caracal",
65
+ "520756": "equus ferus",
66
+ "563151": "panthera leo",
67
+ "70832": "tragelaphus oryx",
68
+ "122645": "kobus ellipsiprymnus",
69
+ "1036755": "phacochoerus africanus",
70
+ "42324": "panthera pardus",
71
+ "159576": "ichneumia albicauda",
72
+ "666235": "canis mesomelas",
73
+ "644255": "syncerus caffer",
74
+ "768674": "giraffa camelopardalis",
75
+ "989807": "alcelaphus buselaphus",
76
+ "571323": "chlorocebus pygerythrus",
77
+ "40168": "madoqua guentheri",
78
+ "995183": "potamochoerus larvatus",
79
+ "346068": "nanger granti",
80
+ "702522": "eudorcas thomsonii",
81
+ "647692": "struthio camelus",
82
+ "561087": "orycteropus afer",
83
+ "752759": "acinonyx jubatus",
84
+ "521834": "eupodotis senegalensis",
85
+ "563163": "felis silvestris",
86
+ "98208": "oryx beisa",
87
+ "3600024": "lophotis gindiana",
88
+ "521837": "ardeotis kori",
89
+ "5021": "lissotis melanogaster",
90
+ "521339": "argusianus argus",
91
+ "280108": "prionailurus bengalensis",
92
+ "194340": "hemigalus derbyanus",
93
+ "194523": "muntiacus muntjak",
94
+ "730013": "sus scrofa",
95
+ "679701": "helarctos malayanus",
96
+ "844145": "rusa unicolor",
97
+ "67361": "hystrix brachyura",
98
+ "42314": "panthera tigris",
99
+ "201068": "lariscus insignis",
100
+ "1032049": "chalcophaps indica",
101
+ "350016": "genetta tigrina",
102
+ "220326": "hystrix cristata",
103
+ "821953": "lycaon pictus",
104
+ "561114": "procavia capensis",
105
+ "989081": "momotus momota",
106
+ "592588": "dasyprocta fuliginosa",
107
+ "736280": "nasua narica",
108
+ "273227": "tamandua mexicana",
109
+ "362785": "didelphis sp",
110
+ "157741": "penelope purpurascens",
111
+ "510752": "camelus dromedarius",
112
+ "821973": "otocyon megalotis",
113
+ "684040": "acryllium vulturinum",
114
+ "1068209": "equus grevyi",
115
+ "563161": "proteles cristata",
116
+ "86170": "leptailurus serval",
117
+ "70827": "tragelaphus strepsiceros",
118
+ "510762": "hippopotamus amphibius",
119
+ "427706": "burhinus capensis",
120
+ "397136": "paguma larvata",
121
+ "660452": "pardofelis marmorata",
122
+ "313163": "cuon alpinus",
123
+ "872963": "varanus salvator",
124
+ "213517": "martes flavigula",
125
+ "194349": "prionodon linsang",
126
+ "352755": "rollulus rouloul",
127
+ "53708": "lophura inornata",
128
+ "110936": "polyplectron chalcurum",
129
+ "644245": "manis javanica",
130
+ "798021": "capricornis sumatraensis",
131
+ "837394": "macaca sp",
132
+ "1080967": "francolinus nobilis",
133
+ "436155": "cephalophus nigrifrons",
134
+ "276723": "atherurus africanus",
135
+ "417950": "pan troglodytes",
136
+ "203191": "cercopithecus mitis",
137
+ "524854": "funisciurus carruthersi",
138
+ "645461": "motacilla flava",
139
+ "3611156": "thamnomys venustus",
140
+ "675198": "protoxerus stangeri",
141
+ "3609124": "paraxerus boehmi",
142
+ "380144": "cephalophus silvicultor",
143
+ "976856": "oenomys hypoxanthus",
144
+ "106895": "melocichla mentalis",
145
+ "666961": "hybomys univittatus",
146
+ "23039": "colomys goslingi",
147
+ "185338": "hylomyscus stella",
148
+ "159587": "genetta servalina",
149
+ "621176": "canis adustus",
150
+ "845966": "mus minutoides",
151
+ "772741": "musophaga rossae",
152
+ "150851": "turtur tympanistria",
153
+ "717794": "praomys tullbergi",
154
+ "782347": "malacomys longipes",
155
+ "693339": "alopochen aegyptiaca",
156
+ "254745": "deomys ferrugineus",
157
+ "96367": "turdus olivaceus",
158
+ "92562": "mazama sp",
159
+ "685113": "urocyon cinereoargenteus",
160
+ "446490": "meleagris ocellata",
161
+ "132829": "crax rubra",
162
+ "9419": "tapirus bairdii",
163
+ "348040": "procyon lotor",
164
+ "410145": "odocoileus virginianus",
165
+ "244142": "leptotila plumbeiceps",
166
+ "3611950": "mazama temama",
167
+ "1023230": "conepatus semistriatus",
168
+ "109882": "ortalis vetula",
169
+ "512437": "presbytis thomasi",
170
+ "882766": "neofelis diardi",
171
+ "3598028": "dendrocitta occipitalis",
172
+ "3598135": "niltava sumatrana",
173
+ "451623": "leiothrix argentauris",
174
+ "3596058": "arborophila rubrirostris",
175
+ "53692": "lophura erythrophthalma",
176
+ "266054": "spilornis cheela",
177
+ "3613295": "herpestes semitorquatus",
178
+ "821960": "cerdocyon thous",
179
+ "407000": "peromyscus sp",
180
+ "3596764": "tigrisoma mexicanum",
181
+ "604964": "claravis pretiosa",
182
+ "421036": "sciurus sp",
183
+ "906602": "aramides cajanea"
184
+ }
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.0
2
+ pandas==1.5.3
3
+ numpy==1.24.2
4
+ Pillow==9.4.0
5
+ scipy==1.10.1
6
+ tensorboard==2.12.2
7
+ torchvision==0.15.1
8
+ tqdm==4.64.1
9
+ wilds==2.0.0
10
+ matplotlib==3.7.1
11
+ gradio==3.50.0
species_class_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:429e1eac2a4cc58b6e3ed0c660ed93b01525628530e5822a076ebf6d878120b9
3
+ size 301166871
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import csv
4
+ import argparse
5
+ import random
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import torch
9
+ import pandas as pd
10
+ import re
11
+
12
+ from torch.utils.data import DataLoader
13
+
14
+ try:
15
+ from torch_geometric.data import Batch
16
+ except ImportError:
17
+ pass
18
+
19
+ def set_seed(seed):
20
+ """Sets seed"""
21
+ if torch.cuda.is_available():
22
+ torch.cuda.manual_seed(seed)
23
+ torch.manual_seed(seed)
24
+ np.random.seed(seed)
25
+ random.seed(seed)
26
+ torch.backends.cudnn.benchmark = False
27
+ torch.backends.cudnn.deterministic = True
28
+
29
+
30
+ def move_to(obj, device):
31
+ if isinstance(obj, dict):
32
+ return {k: move_to(v, device) for k, v in obj.items()}
33
+ elif isinstance(obj, list):
34
+ return [move_to(v, device) for v in obj]
35
+ elif isinstance(obj, float) or isinstance(obj, int):
36
+ return obj
37
+ else:
38
+ # Assume obj is a Tensor or other type
39
+ # (like Batch, for MolPCBA) that supports .to(device)
40
+ return obj.to(device)
41
+
42
+ def detach_and_clone(obj):
43
+ if torch.is_tensor(obj):
44
+ return obj.detach().clone()
45
+ elif isinstance(obj, dict):
46
+ return {k: detach_and_clone(v) for k, v in obj.items()}
47
+ elif isinstance(obj, list):
48
+ return [detach_and_clone(v) for v in obj]
49
+ elif isinstance(obj, float) or isinstance(obj, int):
50
+ return obj
51
+ else:
52
+ raise TypeError("Invalid type for detach_and_clone")
53
+
54
+ def collate_list(vec):
55
+ """
56
+ If vec is a list of Tensors, it concatenates them all along the first dimension.
57
+
58
+ If vec is a list of lists, it joins these lists together, but does not attempt to
59
+ recursively collate. This allows each element of the list to be, e.g., its own dict.
60
+
61
+ If vec is a list of dicts (with the same keys in each dict), it returns a single dict
62
+ with the same keys. For each key, it recursively collates all entries in the list.
63
+ """
64
+ if not isinstance(vec, list):
65
+ raise TypeError("collate_list must take in a list")
66
+ elem = vec[0]
67
+ if torch.is_tensor(elem):
68
+ return torch.cat(vec)
69
+ elif isinstance(elem, list):
70
+ return [obj for sublist in vec for obj in sublist]
71
+ elif isinstance(elem, dict):
72
+ return {k: collate_list([d[k] for d in vec]) for k in elem}
73
+ else:
74
+ raise TypeError("Elements of the list to collate must be tensors or dicts.")