diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..723ef36f4e4f32c4560383aa5987c575a30c6535
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.idea
\ No newline at end of file
diff --git a/CLS2IDX.py b/CLS2IDX.py
new file mode 100644
index 0000000000000000000000000000000000000000..a48d89995ca97a4d0a8db99848d5262e564e9058
--- /dev/null
+++ b/CLS2IDX.py
@@ -0,0 +1,1000 @@
+CLS2IDX = {0: 'tench, Tinca tinca',
+ 1: 'goldfish, Carassius auratus',
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
+ 3: 'tiger shark, Galeocerdo cuvieri',
+ 4: 'hammerhead, hammerhead shark',
+ 5: 'electric ray, crampfish, numbfish, torpedo',
+ 6: 'stingray',
+ 7: 'cock',
+ 8: 'hen',
+ 9: 'ostrich, Struthio camelus',
+ 10: 'brambling, Fringilla montifringilla',
+ 11: 'goldfinch, Carduelis carduelis',
+ 12: 'house finch, linnet, Carpodacus mexicanus',
+ 13: 'junco, snowbird',
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 15: 'robin, American robin, Turdus migratorius',
+ 16: 'bulbul',
+ 17: 'jay',
+ 18: 'magpie',
+ 19: 'chickadee',
+ 20: 'water ouzel, dipper',
+ 21: 'kite',
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 23: 'vulture',
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
+ 25: 'European fire salamander, Salamandra salamandra',
+ 26: 'common newt, Triturus vulgaris',
+ 27: 'eft',
+ 28: 'spotted salamander, Ambystoma maculatum',
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
+ 30: 'bullfrog, Rana catesbeiana',
+ 31: 'tree frog, tree-frog',
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
+ 35: 'mud turtle',
+ 36: 'terrapin',
+ 37: 'box turtle, box tortoise',
+ 38: 'banded gecko',
+ 39: 'common iguana, iguana, Iguana iguana',
+ 40: 'American chameleon, anole, Anolis carolinensis',
+ 41: 'whiptail, whiptail lizard',
+ 42: 'agama',
+ 43: 'frilled lizard, Chlamydosaurus kingi',
+ 44: 'alligator lizard',
+ 45: 'Gila monster, Heloderma suspectum',
+ 46: 'green lizard, Lacerta viridis',
+ 47: 'African chameleon, Chamaeleo chamaeleon',
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 50: 'American alligator, Alligator mississipiensis',
+ 51: 'triceratops',
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
+ 53: 'ringneck snake, ring-necked snake, ring snake',
+ 54: 'hognose snake, puff adder, sand viper',
+ 55: 'green snake, grass snake',
+ 56: 'king snake, kingsnake',
+ 57: 'garter snake, grass snake',
+ 58: 'water snake',
+ 59: 'vine snake',
+ 60: 'night snake, Hypsiglena torquata',
+ 61: 'boa constrictor, Constrictor constrictor',
+ 62: 'rock python, rock snake, Python sebae',
+ 63: 'Indian cobra, Naja naja',
+ 64: 'green mamba',
+ 65: 'sea snake',
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 69: 'trilobite',
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
+ 71: 'scorpion',
+ 72: 'black and gold garden spider, Argiope aurantia',
+ 73: 'barn spider, Araneus cavaticus',
+ 74: 'garden spider, Aranea diademata',
+ 75: 'black widow, Latrodectus mactans',
+ 76: 'tarantula',
+ 77: 'wolf spider, hunting spider',
+ 78: 'tick',
+ 79: 'centipede',
+ 80: 'black grouse',
+ 81: 'ptarmigan',
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
+ 84: 'peacock',
+ 85: 'quail',
+ 86: 'partridge',
+ 87: 'African grey, African gray, Psittacus erithacus',
+ 88: 'macaw',
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 90: 'lorikeet',
+ 91: 'coucal',
+ 92: 'bee eater',
+ 93: 'hornbill',
+ 94: 'hummingbird',
+ 95: 'jacamar',
+ 96: 'toucan',
+ 97: 'drake',
+ 98: 'red-breasted merganser, Mergus serrator',
+ 99: 'goose',
+ 100: 'black swan, Cygnus atratus',
+ 101: 'tusker',
+ 102: 'echidna, spiny anteater, anteater',
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
+ 104: 'wallaby, brush kangaroo',
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
+ 106: 'wombat',
+ 107: 'jellyfish',
+ 108: 'sea anemone, anemone',
+ 109: 'brain coral',
+ 110: 'flatworm, platyhelminth',
+ 111: 'nematode, nematode worm, roundworm',
+ 112: 'conch',
+ 113: 'snail',
+ 114: 'slug',
+ 115: 'sea slug, nudibranch',
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
+ 118: 'Dungeness crab, Cancer magister',
+ 119: 'rock crab, Cancer irroratus',
+ 120: 'fiddler crab',
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
+ 125: 'hermit crab',
+ 126: 'isopod',
+ 127: 'white stork, Ciconia ciconia',
+ 128: 'black stork, Ciconia nigra',
+ 129: 'spoonbill',
+ 130: 'flamingo',
+ 131: 'little blue heron, Egretta caerulea',
+ 132: 'American egret, great white heron, Egretta albus',
+ 133: 'bittern',
+ 134: 'crane',
+ 135: 'limpkin, Aramus pictus',
+ 136: 'European gallinule, Porphyrio porphyrio',
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 138: 'bustard',
+ 139: 'ruddy turnstone, Arenaria interpres',
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
+ 141: 'redshank, Tringa totanus',
+ 142: 'dowitcher',
+ 143: 'oystercatcher, oyster catcher',
+ 144: 'pelican',
+ 145: 'king penguin, Aptenodytes patagonica',
+ 146: 'albatross, mollymawk',
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 149: 'dugong, Dugong dugon',
+ 150: 'sea lion',
+ 151: 'Chihuahua',
+ 152: 'Japanese spaniel',
+ 153: 'Maltese dog, Maltese terrier, Maltese',
+ 154: 'Pekinese, Pekingese, Peke',
+ 155: 'Shih-Tzu',
+ 156: 'Blenheim spaniel',
+ 157: 'papillon',
+ 158: 'toy terrier',
+ 159: 'Rhodesian ridgeback',
+ 160: 'Afghan hound, Afghan',
+ 161: 'basset, basset hound',
+ 162: 'beagle',
+ 163: 'bloodhound, sleuthhound',
+ 164: 'bluetick',
+ 165: 'black-and-tan coonhound',
+ 166: 'Walker hound, Walker foxhound',
+ 167: 'English foxhound',
+ 168: 'redbone',
+ 169: 'borzoi, Russian wolfhound',
+ 170: 'Irish wolfhound',
+ 171: 'Italian greyhound',
+ 172: 'whippet',
+ 173: 'Ibizan hound, Ibizan Podenco',
+ 174: 'Norwegian elkhound, elkhound',
+ 175: 'otterhound, otter hound',
+ 176: 'Saluki, gazelle hound',
+ 177: 'Scottish deerhound, deerhound',
+ 178: 'Weimaraner',
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
+ 181: 'Bedlington terrier',
+ 182: 'Border terrier',
+ 183: 'Kerry blue terrier',
+ 184: 'Irish terrier',
+ 185: 'Norfolk terrier',
+ 186: 'Norwich terrier',
+ 187: 'Yorkshire terrier',
+ 188: 'wire-haired fox terrier',
+ 189: 'Lakeland terrier',
+ 190: 'Sealyham terrier, Sealyham',
+ 191: 'Airedale, Airedale terrier',
+ 192: 'cairn, cairn terrier',
+ 193: 'Australian terrier',
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
+ 195: 'Boston bull, Boston terrier',
+ 196: 'miniature schnauzer',
+ 197: 'giant schnauzer',
+ 198: 'standard schnauzer',
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
+ 200: 'Tibetan terrier, chrysanthemum dog',
+ 201: 'silky terrier, Sydney silky',
+ 202: 'soft-coated wheaten terrier',
+ 203: 'West Highland white terrier',
+ 204: 'Lhasa, Lhasa apso',
+ 205: 'flat-coated retriever',
+ 206: 'curly-coated retriever',
+ 207: 'golden retriever',
+ 208: 'Labrador retriever',
+ 209: 'Chesapeake Bay retriever',
+ 210: 'German short-haired pointer',
+ 211: 'vizsla, Hungarian pointer',
+ 212: 'English setter',
+ 213: 'Irish setter, red setter',
+ 214: 'Gordon setter',
+ 215: 'Brittany spaniel',
+ 216: 'clumber, clumber spaniel',
+ 217: 'English springer, English springer spaniel',
+ 218: 'Welsh springer spaniel',
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
+ 220: 'Sussex spaniel',
+ 221: 'Irish water spaniel',
+ 222: 'kuvasz',
+ 223: 'schipperke',
+ 224: 'groenendael',
+ 225: 'malinois',
+ 226: 'briard',
+ 227: 'kelpie',
+ 228: 'komondor',
+ 229: 'Old English sheepdog, bobtail',
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 231: 'collie',
+ 232: 'Border collie',
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
+ 234: 'Rottweiler',
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 236: 'Doberman, Doberman pinscher',
+ 237: 'miniature pinscher',
+ 238: 'Greater Swiss Mountain dog',
+ 239: 'Bernese mountain dog',
+ 240: 'Appenzeller',
+ 241: 'EntleBucher',
+ 242: 'boxer',
+ 243: 'bull mastiff',
+ 244: 'Tibetan mastiff',
+ 245: 'French bulldog',
+ 246: 'Great Dane',
+ 247: 'Saint Bernard, St Bernard',
+ 248: 'Eskimo dog, husky',
+ 249: 'malamute, malemute, Alaskan malamute',
+ 250: 'Siberian husky',
+ 251: 'dalmatian, coach dog, carriage dog',
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
+ 253: 'basenji',
+ 254: 'pug, pug-dog',
+ 255: 'Leonberg',
+ 256: 'Newfoundland, Newfoundland dog',
+ 257: 'Great Pyrenees',
+ 258: 'Samoyed, Samoyede',
+ 259: 'Pomeranian',
+ 260: 'chow, chow chow',
+ 261: 'keeshond',
+ 262: 'Brabancon griffon',
+ 263: 'Pembroke, Pembroke Welsh corgi',
+ 264: 'Cardigan, Cardigan Welsh corgi',
+ 265: 'toy poodle',
+ 266: 'miniature poodle',
+ 267: 'standard poodle',
+ 268: 'Mexican hairless',
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 273: 'dingo, warrigal, warragal, Canis dingo',
+ 274: 'dhole, Cuon alpinus',
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 276: 'hyena, hyaena',
+ 277: 'red fox, Vulpes vulpes',
+ 278: 'kit fox, Vulpes macrotis',
+ 279: 'Arctic fox, white fox, Alopex lagopus',
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 281: 'tabby, tabby cat',
+ 282: 'tiger cat',
+ 283: 'Persian cat',
+ 284: 'Siamese cat, Siamese',
+ 285: 'Egyptian cat',
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 287: 'lynx, catamount',
+ 288: 'leopard, Panthera pardus',
+ 289: 'snow leopard, ounce, Panthera uncia',
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
+ 291: 'lion, king of beasts, Panthera leo',
+ 292: 'tiger, Panthera tigris',
+ 293: 'cheetah, chetah, Acinonyx jubatus',
+ 294: 'brown bear, bruin, Ursus arctos',
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 298: 'mongoose',
+ 299: 'meerkat, mierkat',
+ 300: 'tiger beetle',
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 302: 'ground beetle, carabid beetle',
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
+ 304: 'leaf beetle, chrysomelid',
+ 305: 'dung beetle',
+ 306: 'rhinoceros beetle',
+ 307: 'weevil',
+ 308: 'fly',
+ 309: 'bee',
+ 310: 'ant, emmet, pismire',
+ 311: 'grasshopper, hopper',
+ 312: 'cricket',
+ 313: 'walking stick, walkingstick, stick insect',
+ 314: 'cockroach, roach',
+ 315: 'mantis, mantid',
+ 316: 'cicada, cicala',
+ 317: 'leafhopper',
+ 318: 'lacewing, lacewing fly',
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ 320: 'damselfly',
+ 321: 'admiral',
+ 322: 'ringlet, ringlet butterfly',
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 324: 'cabbage butterfly',
+ 325: 'sulphur butterfly, sulfur butterfly',
+ 326: 'lycaenid, lycaenid butterfly',
+ 327: 'starfish, sea star',
+ 328: 'sea urchin',
+ 329: 'sea cucumber, holothurian',
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
+ 331: 'hare',
+ 332: 'Angora, Angora rabbit',
+ 333: 'hamster',
+ 334: 'porcupine, hedgehog',
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 336: 'marmot',
+ 337: 'beaver',
+ 338: 'guinea pig, Cavia cobaya',
+ 339: 'sorrel',
+ 340: 'zebra',
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
+ 342: 'wild boar, boar, Sus scrofa',
+ 343: 'warthog',
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 345: 'ox',
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 347: 'bison',
+ 348: 'ram, tup',
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
+ 350: 'ibex, Capra ibex',
+ 351: 'hartebeest',
+ 352: 'impala, Aepyceros melampus',
+ 353: 'gazelle',
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
+ 355: 'llama',
+ 356: 'weasel',
+ 357: 'mink',
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
+ 360: 'otter',
+ 361: 'skunk, polecat, wood pussy',
+ 362: 'badger',
+ 363: 'armadillo',
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 366: 'gorilla, Gorilla gorilla',
+ 367: 'chimpanzee, chimp, Pan troglodytes',
+ 368: 'gibbon, Hylobates lar',
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 370: 'guenon, guenon monkey',
+ 371: 'patas, hussar monkey, Erythrocebus patas',
+ 372: 'baboon',
+ 373: 'macaque',
+ 374: 'langur',
+ 375: 'colobus, colobus monkey',
+ 376: 'proboscis monkey, Nasalis larvatus',
+ 377: 'marmoset',
+ 378: 'capuchin, ringtail, Cebus capucinus',
+ 379: 'howler monkey, howler',
+ 380: 'titi, titi monkey',
+ 381: 'spider monkey, Ateles geoffroyi',
+ 382: 'squirrel monkey, Saimiri sciureus',
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
+ 385: 'Indian elephant, Elephas maximus',
+ 386: 'African elephant, Loxodonta africana',
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 389: 'barracouta, snoek',
+ 390: 'eel',
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
+ 392: 'rock beauty, Holocanthus tricolor',
+ 393: 'anemone fish',
+ 394: 'sturgeon',
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 396: 'lionfish',
+ 397: 'puffer, pufferfish, blowfish, globefish',
+ 398: 'abacus',
+ 399: 'abaya',
+ 400: "academic gown, academic robe, judge's robe",
+ 401: 'accordion, piano accordion, squeeze box',
+ 402: 'acoustic guitar',
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 404: 'airliner',
+ 405: 'airship, dirigible',
+ 406: 'altar',
+ 407: 'ambulance',
+ 408: 'amphibian, amphibious vehicle',
+ 409: 'analog clock',
+ 410: 'apiary, bee house',
+ 411: 'apron',
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
+ 413: 'assault rifle, assault gun',
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 415: 'bakery, bakeshop, bakehouse',
+ 416: 'balance beam, beam',
+ 417: 'balloon',
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
+ 419: 'Band Aid',
+ 420: 'banjo',
+ 421: 'bannister, banister, balustrade, balusters, handrail',
+ 422: 'barbell',
+ 423: 'barber chair',
+ 424: 'barbershop',
+ 425: 'barn',
+ 426: 'barometer',
+ 427: 'barrel, cask',
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
+ 429: 'baseball',
+ 430: 'basketball',
+ 431: 'bassinet',
+ 432: 'bassoon',
+ 433: 'bathing cap, swimming cap',
+ 434: 'bath towel',
+ 435: 'bathtub, bathing tub, bath, tub',
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
+ 437: 'beacon, lighthouse, beacon light, pharos',
+ 438: 'beaker',
+ 439: 'bearskin, busby, shako',
+ 440: 'beer bottle',
+ 441: 'beer glass',
+ 442: 'bell cote, bell cot',
+ 443: 'bib',
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
+ 445: 'bikini, two-piece',
+ 446: 'binder, ring-binder',
+ 447: 'binoculars, field glasses, opera glasses',
+ 448: 'birdhouse',
+ 449: 'boathouse',
+ 450: 'bobsled, bobsleigh, bob',
+ 451: 'bolo tie, bolo, bola tie, bola',
+ 452: 'bonnet, poke bonnet',
+ 453: 'bookcase',
+ 454: 'bookshop, bookstore, bookstall',
+ 455: 'bottlecap',
+ 456: 'bow',
+ 457: 'bow tie, bow-tie, bowtie',
+ 458: 'brass, memorial tablet, plaque',
+ 459: 'brassiere, bra, bandeau',
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 461: 'breastplate, aegis, egis',
+ 462: 'broom',
+ 463: 'bucket, pail',
+ 464: 'buckle',
+ 465: 'bulletproof vest',
+ 466: 'bullet train, bullet',
+ 467: 'butcher shop, meat market',
+ 468: 'cab, hack, taxi, taxicab',
+ 469: 'caldron, cauldron',
+ 470: 'candle, taper, wax light',
+ 471: 'cannon',
+ 472: 'canoe',
+ 473: 'can opener, tin opener',
+ 474: 'cardigan',
+ 475: 'car mirror',
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ 477: "carpenter's kit, tool kit",
+ 478: 'carton',
+ 479: 'car wheel',
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
+ 481: 'cassette',
+ 482: 'cassette player',
+ 483: 'castle',
+ 484: 'catamaran',
+ 485: 'CD player',
+ 486: 'cello, violoncello',
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 488: 'chain',
+ 489: 'chainlink fence',
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
+ 491: 'chain saw, chainsaw',
+ 492: 'chest',
+ 493: 'chiffonier, commode',
+ 494: 'chime, bell, gong',
+ 495: 'china cabinet, china closet',
+ 496: 'Christmas stocking',
+ 497: 'church, church building',
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 499: 'cleaver, meat cleaver, chopper',
+ 500: 'cliff dwelling',
+ 501: 'cloak',
+ 502: 'clog, geta, patten, sabot',
+ 503: 'cocktail shaker',
+ 504: 'coffee mug',
+ 505: 'coffeepot',
+ 506: 'coil, spiral, volute, whorl, helix',
+ 507: 'combination lock',
+ 508: 'computer keyboard, keypad',
+ 509: 'confectionery, confectionary, candy store',
+ 510: 'container ship, containership, container vessel',
+ 511: 'convertible',
+ 512: 'corkscrew, bottle screw',
+ 513: 'cornet, horn, trumpet, trump',
+ 514: 'cowboy boot',
+ 515: 'cowboy hat, ten-gallon hat',
+ 516: 'cradle',
+ 517: 'crane',
+ 518: 'crash helmet',
+ 519: 'crate',
+ 520: 'crib, cot',
+ 521: 'Crock Pot',
+ 522: 'croquet ball',
+ 523: 'crutch',
+ 524: 'cuirass',
+ 525: 'dam, dike, dyke',
+ 526: 'desk',
+ 527: 'desktop computer',
+ 528: 'dial telephone, dial phone',
+ 529: 'diaper, nappy, napkin',
+ 530: 'digital clock',
+ 531: 'digital watch',
+ 532: 'dining table, board',
+ 533: 'dishrag, dishcloth',
+ 534: 'dishwasher, dish washer, dishwashing machine',
+ 535: 'disk brake, disc brake',
+ 536: 'dock, dockage, docking facility',
+ 537: 'dogsled, dog sled, dog sleigh',
+ 538: 'dome',
+ 539: 'doormat, welcome mat',
+ 540: 'drilling platform, offshore rig',
+ 541: 'drum, membranophone, tympan',
+ 542: 'drumstick',
+ 543: 'dumbbell',
+ 544: 'Dutch oven',
+ 545: 'electric fan, blower',
+ 546: 'electric guitar',
+ 547: 'electric locomotive',
+ 548: 'entertainment center',
+ 549: 'envelope',
+ 550: 'espresso maker',
+ 551: 'face powder',
+ 552: 'feather boa, boa',
+ 553: 'file, file cabinet, filing cabinet',
+ 554: 'fireboat',
+ 555: 'fire engine, fire truck',
+ 556: 'fire screen, fireguard',
+ 557: 'flagpole, flagstaff',
+ 558: 'flute, transverse flute',
+ 559: 'folding chair',
+ 560: 'football helmet',
+ 561: 'forklift',
+ 562: 'fountain',
+ 563: 'fountain pen',
+ 564: 'four-poster',
+ 565: 'freight car',
+ 566: 'French horn, horn',
+ 567: 'frying pan, frypan, skillet',
+ 568: 'fur coat',
+ 569: 'garbage truck, dustcart',
+ 570: 'gasmask, respirator, gas helmet',
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 572: 'goblet',
+ 573: 'go-kart',
+ 574: 'golf ball',
+ 575: 'golfcart, golf cart',
+ 576: 'gondola',
+ 577: 'gong, tam-tam',
+ 578: 'gown',
+ 579: 'grand piano, grand',
+ 580: 'greenhouse, nursery, glasshouse',
+ 581: 'grille, radiator grille',
+ 582: 'grocery store, grocery, food market, market',
+ 583: 'guillotine',
+ 584: 'hair slide',
+ 585: 'hair spray',
+ 586: 'half track',
+ 587: 'hammer',
+ 588: 'hamper',
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 590: 'hand-held computer, hand-held microcomputer',
+ 591: 'handkerchief, hankie, hanky, hankey',
+ 592: 'hard disc, hard disk, fixed disk',
+ 593: 'harmonica, mouth organ, harp, mouth harp',
+ 594: 'harp',
+ 595: 'harvester, reaper',
+ 596: 'hatchet',
+ 597: 'holster',
+ 598: 'home theater, home theatre',
+ 599: 'honeycomb',
+ 600: 'hook, claw',
+ 601: 'hoopskirt, crinoline',
+ 602: 'horizontal bar, high bar',
+ 603: 'horse cart, horse-cart',
+ 604: 'hourglass',
+ 605: 'iPod',
+ 606: 'iron, smoothing iron',
+ 607: "jack-o'-lantern",
+ 608: 'jean, blue jean, denim',
+ 609: 'jeep, landrover',
+ 610: 'jersey, T-shirt, tee shirt',
+ 611: 'jigsaw puzzle',
+ 612: 'jinrikisha, ricksha, rickshaw',
+ 613: 'joystick',
+ 614: 'kimono',
+ 615: 'knee pad',
+ 616: 'knot',
+ 617: 'lab coat, laboratory coat',
+ 618: 'ladle',
+ 619: 'lampshade, lamp shade',
+ 620: 'laptop, laptop computer',
+ 621: 'lawn mower, mower',
+ 622: 'lens cap, lens cover',
+ 623: 'letter opener, paper knife, paperknife',
+ 624: 'library',
+ 625: 'lifeboat',
+ 626: 'lighter, light, igniter, ignitor',
+ 627: 'limousine, limo',
+ 628: 'liner, ocean liner',
+ 629: 'lipstick, lip rouge',
+ 630: 'Loafer',
+ 631: 'lotion',
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
+ 633: "loupe, jeweler's loupe",
+ 634: 'lumbermill, sawmill',
+ 635: 'magnetic compass',
+ 636: 'mailbag, postbag',
+ 637: 'mailbox, letter box',
+ 638: 'maillot',
+ 639: 'maillot, tank suit',
+ 640: 'manhole cover',
+ 641: 'maraca',
+ 642: 'marimba, xylophone',
+ 643: 'mask',
+ 644: 'matchstick',
+ 645: 'maypole',
+ 646: 'maze, labyrinth',
+ 647: 'measuring cup',
+ 648: 'medicine chest, medicine cabinet',
+ 649: 'megalith, megalithic structure',
+ 650: 'microphone, mike',
+ 651: 'microwave, microwave oven',
+ 652: 'military uniform',
+ 653: 'milk can',
+ 654: 'minibus',
+ 655: 'miniskirt, mini',
+ 656: 'minivan',
+ 657: 'missile',
+ 658: 'mitten',
+ 659: 'mixing bowl',
+ 660: 'mobile home, manufactured home',
+ 661: 'Model T',
+ 662: 'modem',
+ 663: 'monastery',
+ 664: 'monitor',
+ 665: 'moped',
+ 666: 'mortar',
+ 667: 'mortarboard',
+ 668: 'mosque',
+ 669: 'mosquito net',
+ 670: 'motor scooter, scooter',
+ 671: 'mountain bike, all-terrain bike, off-roader',
+ 672: 'mountain tent',
+ 673: 'mouse, computer mouse',
+ 674: 'mousetrap',
+ 675: 'moving van',
+ 676: 'muzzle',
+ 677: 'nail',
+ 678: 'neck brace',
+ 679: 'necklace',
+ 680: 'nipple',
+ 681: 'notebook, notebook computer',
+ 682: 'obelisk',
+ 683: 'oboe, hautboy, hautbois',
+ 684: 'ocarina, sweet potato',
+ 685: 'odometer, hodometer, mileometer, milometer',
+ 686: 'oil filter',
+ 687: 'organ, pipe organ',
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 689: 'overskirt',
+ 690: 'oxcart',
+ 691: 'oxygen mask',
+ 692: 'packet',
+ 693: 'paddle, boat paddle',
+ 694: 'paddlewheel, paddle wheel',
+ 695: 'padlock',
+ 696: 'paintbrush',
+ 697: "pajama, pyjama, pj's, jammies",
+ 698: 'palace',
+ 699: 'panpipe, pandean pipe, syrinx',
+ 700: 'paper towel',
+ 701: 'parachute, chute',
+ 702: 'parallel bars, bars',
+ 703: 'park bench',
+ 704: 'parking meter',
+ 705: 'passenger car, coach, carriage',
+ 706: 'patio, terrace',
+ 707: 'pay-phone, pay-station',
+ 708: 'pedestal, plinth, footstall',
+ 709: 'pencil box, pencil case',
+ 710: 'pencil sharpener',
+ 711: 'perfume, essence',
+ 712: 'Petri dish',
+ 713: 'photocopier',
+ 714: 'pick, plectrum, plectron',
+ 715: 'pickelhaube',
+ 716: 'picket fence, paling',
+ 717: 'pickup, pickup truck',
+ 718: 'pier',
+ 719: 'piggy bank, penny bank',
+ 720: 'pill bottle',
+ 721: 'pillow',
+ 722: 'ping-pong ball',
+ 723: 'pinwheel',
+ 724: 'pirate, pirate ship',
+ 725: 'pitcher, ewer',
+ 726: "plane, carpenter's plane, woodworking plane",
+ 727: 'planetarium',
+ 728: 'plastic bag',
+ 729: 'plate rack',
+ 730: 'plow, plough',
+ 731: "plunger, plumber's helper",
+ 732: 'Polaroid camera, Polaroid Land camera',
+ 733: 'pole',
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
+ 735: 'poncho',
+ 736: 'pool table, billiard table, snooker table',
+ 737: 'pop bottle, soda bottle',
+ 738: 'pot, flowerpot',
+ 739: "potter's wheel",
+ 740: 'power drill',
+ 741: 'prayer rug, prayer mat',
+ 742: 'printer',
+ 743: 'prison, prison house',
+ 744: 'projectile, missile',
+ 745: 'projector',
+ 746: 'puck, hockey puck',
+ 747: 'punching bag, punch bag, punching ball, punchball',
+ 748: 'purse',
+ 749: 'quill, quill pen',
+ 750: 'quilt, comforter, comfort, puff',
+ 751: 'racer, race car, racing car',
+ 752: 'racket, racquet',
+ 753: 'radiator',
+ 754: 'radio, wireless',
+ 755: 'radio telescope, radio reflector',
+ 756: 'rain barrel',
+ 757: 'recreational vehicle, RV, R.V.',
+ 758: 'reel',
+ 759: 'reflex camera',
+ 760: 'refrigerator, icebox',
+ 761: 'remote control, remote',
+ 762: 'restaurant, eating house, eating place, eatery',
+ 763: 'revolver, six-gun, six-shooter',
+ 764: 'rifle',
+ 765: 'rocking chair, rocker',
+ 766: 'rotisserie',
+ 767: 'rubber eraser, rubber, pencil eraser',
+ 768: 'rugby ball',
+ 769: 'rule, ruler',
+ 770: 'running shoe',
+ 771: 'safe',
+ 772: 'safety pin',
+ 773: 'saltshaker, salt shaker',
+ 774: 'sandal',
+ 775: 'sarong',
+ 776: 'sax, saxophone',
+ 777: 'scabbard',
+ 778: 'scale, weighing machine',
+ 779: 'school bus',
+ 780: 'schooner',
+ 781: 'scoreboard',
+ 782: 'screen, CRT screen',
+ 783: 'screw',
+ 784: 'screwdriver',
+ 785: 'seat belt, seatbelt',
+ 786: 'sewing machine',
+ 787: 'shield, buckler',
+ 788: 'shoe shop, shoe-shop, shoe store',
+ 789: 'shoji',
+ 790: 'shopping basket',
+ 791: 'shopping cart',
+ 792: 'shovel',
+ 793: 'shower cap',
+ 794: 'shower curtain',
+ 795: 'ski',
+ 796: 'ski mask',
+ 797: 'sleeping bag',
+ 798: 'slide rule, slipstick',
+ 799: 'sliding door',
+ 800: 'slot, one-armed bandit',
+ 801: 'snorkel',
+ 802: 'snowmobile',
+ 803: 'snowplow, snowplough',
+ 804: 'soap dispenser',
+ 805: 'soccer ball',
+ 806: 'sock',
+ 807: 'solar dish, solar collector, solar furnace',
+ 808: 'sombrero',
+ 809: 'soup bowl',
+ 810: 'space bar',
+ 811: 'space heater',
+ 812: 'space shuttle',
+ 813: 'spatula',
+ 814: 'speedboat',
+ 815: "spider web, spider's web",
+ 816: 'spindle',
+ 817: 'sports car, sport car',
+ 818: 'spotlight, spot',
+ 819: 'stage',
+ 820: 'steam locomotive',
+ 821: 'steel arch bridge',
+ 822: 'steel drum',
+ 823: 'stethoscope',
+ 824: 'stole',
+ 825: 'stone wall',
+ 826: 'stopwatch, stop watch',
+ 827: 'stove',
+ 828: 'strainer',
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
+ 830: 'stretcher',
+ 831: 'studio couch, day bed',
+ 832: 'stupa, tope',
+ 833: 'submarine, pigboat, sub, U-boat',
+ 834: 'suit, suit of clothes',
+ 835: 'sundial',
+ 836: 'sunglass',
+ 837: 'sunglasses, dark glasses, shades',
+ 838: 'sunscreen, sunblock, sun blocker',
+ 839: 'suspension bridge',
+ 840: 'swab, swob, mop',
+ 841: 'sweatshirt',
+ 842: 'swimming trunks, bathing trunks',
+ 843: 'swing',
+ 844: 'switch, electric switch, electrical switch',
+ 845: 'syringe',
+ 846: 'table lamp',
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 848: 'tape player',
+ 849: 'teapot',
+ 850: 'teddy, teddy bear',
+ 851: 'television, television system',
+ 852: 'tennis ball',
+ 853: 'thatch, thatched roof',
+ 854: 'theater curtain, theatre curtain',
+ 855: 'thimble',
+ 856: 'thresher, thrasher, threshing machine',
+ 857: 'throne',
+ 858: 'tile roof',
+ 859: 'toaster',
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
+ 861: 'toilet seat',
+ 862: 'torch',
+ 863: 'totem pole',
+ 864: 'tow truck, tow car, wrecker',
+ 865: 'toyshop',
+ 866: 'tractor',
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
+ 868: 'tray',
+ 869: 'trench coat',
+ 870: 'tricycle, trike, velocipede',
+ 871: 'trimaran',
+ 872: 'tripod',
+ 873: 'triumphal arch',
+ 874: 'trolleybus, trolley coach, trackless trolley',
+ 875: 'trombone',
+ 876: 'tub, vat',
+ 877: 'turnstile',
+ 878: 'typewriter keyboard',
+ 879: 'umbrella',
+ 880: 'unicycle, monocycle',
+ 881: 'upright, upright piano',
+ 882: 'vacuum, vacuum cleaner',
+ 883: 'vase',
+ 884: 'vault',
+ 885: 'velvet',
+ 886: 'vending machine',
+ 887: 'vestment',
+ 888: 'viaduct',
+ 889: 'violin, fiddle',
+ 890: 'volleyball',
+ 891: 'waffle iron',
+ 892: 'wall clock',
+ 893: 'wallet, billfold, notecase, pocketbook',
+ 894: 'wardrobe, closet, press',
+ 895: 'warplane, military plane',
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 897: 'washer, automatic washer, washing machine',
+ 898: 'water bottle',
+ 899: 'water jug',
+ 900: 'water tower',
+ 901: 'whiskey jug',
+ 902: 'whistle',
+ 903: 'wig',
+ 904: 'window screen',
+ 905: 'window shade',
+ 906: 'Windsor tie',
+ 907: 'wine bottle',
+ 908: 'wing',
+ 909: 'wok',
+ 910: 'wooden spoon',
+ 911: 'wool, woolen, woollen',
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 913: 'wreck',
+ 914: 'yawl',
+ 915: 'yurt',
+ 916: 'web site, website, internet site, site',
+ 917: 'comic book',
+ 918: 'crossword puzzle, crossword',
+ 919: 'street sign',
+ 920: 'traffic light, traffic signal, stoplight',
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
+ 922: 'menu',
+ 923: 'plate',
+ 924: 'guacamole',
+ 925: 'consomme',
+ 926: 'hot pot, hotpot',
+ 927: 'trifle',
+ 928: 'ice cream, icecream',
+ 929: 'ice lolly, lolly, lollipop, popsicle',
+ 930: 'French loaf',
+ 931: 'bagel, beigel',
+ 932: 'pretzel',
+ 933: 'cheeseburger',
+ 934: 'hotdog, hot dog, red hot',
+ 935: 'mashed potato',
+ 936: 'head cabbage',
+ 937: 'broccoli',
+ 938: 'cauliflower',
+ 939: 'zucchini, courgette',
+ 940: 'spaghetti squash',
+ 941: 'acorn squash',
+ 942: 'butternut squash',
+ 943: 'cucumber, cuke',
+ 944: 'artichoke, globe artichoke',
+ 945: 'bell pepper',
+ 946: 'cardoon',
+ 947: 'mushroom',
+ 948: 'Granny Smith',
+ 949: 'strawberry',
+ 950: 'orange',
+ 951: 'lemon',
+ 952: 'fig',
+ 953: 'pineapple, ananas',
+ 954: 'banana',
+ 955: 'jackfruit, jak, jack',
+ 956: 'custard apple',
+ 957: 'pomegranate',
+ 958: 'hay',
+ 959: 'carbonara',
+ 960: 'chocolate sauce, chocolate syrup',
+ 961: 'dough',
+ 962: 'meat loaf, meatloaf',
+ 963: 'pizza, pizza pie',
+ 964: 'potpie',
+ 965: 'burrito',
+ 966: 'red wine',
+ 967: 'espresso',
+ 968: 'cup',
+ 969: 'eggnog',
+ 970: 'alp',
+ 971: 'bubble',
+ 972: 'cliff, drop, drop-off',
+ 973: 'coral reef',
+ 974: 'geyser',
+ 975: 'lakeside, lakeshore',
+ 976: 'promontory, headland, head, foreland',
+ 977: 'sandbar, sand bar',
+ 978: 'seashore, coast, seacoast, sea-coast',
+ 979: 'valley, vale',
+ 980: 'volcano',
+ 981: 'ballplayer, baseball player',
+ 982: 'groom, bridegroom',
+ 983: 'scuba diver',
+ 984: 'rapeseed',
+ 985: 'daisy',
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ 987: 'corn',
+ 988: 'acorn',
+ 989: 'hip, rose hip, rosehip',
+ 990: 'buckeye, horse chestnut, conker',
+ 991: 'coral fungus',
+ 992: 'agaric',
+ 993: 'gyromitra',
+ 994: 'stinkhorn, carrion fungus',
+ 995: 'earthstar',
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
+ 997: 'bolete',
+ 998: 'ear, spike, capitulum',
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
\ No newline at end of file
diff --git a/README.md b/README.md
index cd82ea040710ab67ecce123baf0d90cc9ad51eaf..5a56d4dc63fcd1502b4aa2e5de3c26323f4be8a5 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,124 @@
----
-title: RobustViT
-emoji: ⚡
-colorFrom: red
-colorTo: indigo
-sdk: gradio
-sdk_version: 3.0.11
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+# RobustViT
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/RobustViT/blob/master/RobustViT.ipynb)
+
+Official PyTorch implementation of **Optimizing Relevance Maps of Vision Transformers Improves Robustness**. This code allows to
+finetune the explainability maps of Vision Transformers to enhance robustness.
+
+The method employs loss functions directly to the explainability maps to ensure that the model is focused mostly on the foreground of the image:
+
+
+
+Using a short finetuning process with only 3 labeled examples from 500 classes, our method imrpoves robustness of ViT models across different model sizes and training techniques, even when data augmentations/ regularization are applied.
+
+## Producing Segmenataion Data
+### Using ImageNet-S
+To use the ImageNet-S labeled data, [download the `ImageNetS919` dataset](https://github.com/UnsupervisedSemanticSegmentation/ImageNet-S)
+
+### Using TokenCut for unsupervised segmentation
+1. Clone the TokenCut project
+ ```
+ git clone https://github.com/YangtaoWANG95/TokenCut.git
+ ```
+2. Install the dependencies
+ Python 3.7, PyTorch 1.7.1 and CUDA 11.2. Please refer to the official installation. If CUDA 10.2 has been properly installed:
+ ```
+ pip install torch==1.7.1 torchvision==0.8.2
+ ```
+ Followed by
+ ```
+ pip install -r TokenCut/requirements.txt
+
+3. Use the following command to extract the segmentation maps:
+ ```
+ python tokencut_generate_segmentation.py --img_path --out_dir
+ ```
+
+
+## Finetuning ViT models
+
+To finetune a pretrained ViT model use the `imagenet_finetune.py` script. Notice to uncomment the import line containing the pretrained model you
+wish to finetune.
+
+Usage example:
+
+```bash
+python imagenet_finetune.py --seg_data --data --gpu 0 --lr --lambda_seg --lambda_acc --lambda_background --lambda_foreground
+```
+
+Notes:
+
+* For all models we use :
+ * `lambda_seg=0.8`
+ * `lambda_acc=0.2`
+ * `lambda_background=2`
+ * `lambda_foreground=0.3`
+ * For **DeiT** models, a temprature is required as follows:
+ * `temprature=0.65` for DeiT-B
+ * `temprature=0.55` for DeiT-S
+ * The learning rates per model are:
+ * ViT-B: 3e-6
+ * ViT-L: 9e-7
+ * AR-S: 2e-6
+ * AR-B: 6e-7
+ * AR-L: 9e-7
+ * DeiT-S: 1e-6
+ * DeiT-B: 8e-7
+
+## Baseline methods
+Notice to uncomment the import line containing the pretrained model you wish to finetune in the code.
+
+### GradMask
+Run the following command:
+```bash
+python imagenet_finetune_gradmask.py --seg_data --data --gpu 0 --lr --lambda_seg --lambda_acc
+```
+All hyperparameters for the different models can be found in section D of the supplementary material.
+
+### Right for the Right Reasons
+Run the following command:
+```bash
+python imagenet_finetune_rrr.py --seg_data --data --gpu 0 --lr --lambda_seg --lambda_acc
+```
+All hyperparameters for the different models can be found in section D of the supplementary material.
+
+## Evaluation
+
+### Robustness Evaluation
+
+1. Download the evaluation datasets:
+ * [INet-A](https://github.com/hendrycks/natural-adv-examples)
+ * [INet-R](https://github.com/hendrycks/imagenet-r)
+ * [INet-v2](https://github.com/modestyachts/ImageNetV2)
+ * [ObjectNet](https://objectnet.dev/)
+ * [SI-Score](https://github.com/google-research/si-score)
+
+2. Run the following script to evaluate:
+
+```bash
+python imagenet_eval_robustness.py --data --batch-size --evaluate --checkpoint
+```
+* Notice to uncomment the import line containing the pretrained model you wish to evaluate in the code.
+* To evaluate the original model simply omit the `checkpoint` parameter.
+* For the INet-v2 dataset add `--isV2`.
+* For the ObjectNet dataset add `--isObjectNet`.
+* For the SI datasets add `--isSI`.
+
+### Segmentation Evaluation
+Our segmentation tests are based on the test in the official implementation of [Transformer Interpretability Beyond Attention Visualization](https://github.com/hila-chefer/Transformer-Explainability).
+1. [Download the ImageNet segmentation test set](https://github.com/hila-chefer/Transformer-Explainability#section-a-segmentation-results).
+2. Run the following script to evaluate:
+
+ ```bash
+PYTHONPATH=./:$PYTHONPATH python SegmentationTest/imagenet_seg_eval.py --imagenet-seg-path
+```
+* Notice to uncomment the import line containing the pretrained model you wish to evaluate in the code.
+
+### Credits
+* The TokenCut code is built on top of [LOST](https://github.com/valeoai/LOST), [DINO](https://github.com/facebookresearch/dino), [Segswap](https://github.com/XiSHEN0220/SegSwap), and [Bilateral_Sovlver](https://github.com/poolio/bilateral_solver).
+* Our ViT code is based on the [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) repository.
+* Our ImageNet finetuning code is based on [code from the official PyTorch repo](https://github.com/pytorch/examples/blob/main/imagenet/main.py).
+* The code to convert ObjectNet classes to ImageNet classes was taken from [the torchprune repo](https://github.com/lucaslie/torchprune/blob/b753745b773c3ed259bf819d193ce8573d89efbb/src/torchprune/torchprune/util/datasets/objectnet.py).
+* The code to convert SI-Score classes to ImageNet classes was taken from [the official implementation](https://github.com/google-research/si-score).
+
+We would like to sincerely thank the authors for their great works.
diff --git a/RobustViT.ipynb b/RobustViT.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..060189c623e95ae77edad6564d42f9ce4cc98d30
--- /dev/null
+++ b/RobustViT.ipynb
@@ -0,0 +1,1454 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "RobustViT.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "authorship_tag": "ABX9TyNP00yXydKk0stZEJQyT5pO",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hzG96yZskJSy",
+ "outputId": "3eab22fa-e246-4cfb-d4a9-c35878cf75f2"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'RobustViT'...\n",
+ "remote: Enumerating objects: 139, done.\u001b[K\n",
+ "remote: Counting objects: 100% (139/139), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (119/119), done.\u001b[K\n",
+ "remote: Total 139 (delta 54), reused 84 (delta 18), pack-reused 0\u001b[K\n",
+ "Receiving objects: 100% (139/139), 4.50 MiB | 17.11 MiB/s, done.\n",
+ "Resolving deltas: 100% (54/54), done.\n"
+ ]
+ }
+ ],
+ "source": [
+ "!git clone https://github.com/hila-chefer/RobustViT.git\n",
+ "\n",
+ "import os\n",
+ "os.chdir(f'./RobustViT')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install timm\n",
+ "!pip install einops"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hZK84BL3mZQg",
+ "outputId": "f9273be0-3410-47cf-f52e-2a45989e9b1f"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting timm\n",
+ " Downloading timm-0.5.4-py3-none-any.whl (431 kB)\n",
+ "\u001b[K |████████████████████████████████| 431 kB 4.4 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from timm) (1.11.0+cu113)\n",
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm) (0.12.0+cu113)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4->timm) (4.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (1.21.6)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (7.1.2)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (2.23.0)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (3.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2022.5.18.1)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (1.24.3)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2.10)\n",
+ "Installing collected packages: timm\n",
+ "Successfully installed timm-0.5.4\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting einops\n",
+ " Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
+ "Installing collected packages: einops\n",
+ "Successfully installed einops-0.4.1\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from PIL import Image\n",
+ "import torchvision.transforms as transforms\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import cv2\n",
+ "from CLS2IDX import CLS2IDX\n",
+ "\n",
+ "normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
+ "transform = transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(224),\n",
+ " transforms.ToTensor(),\n",
+ " normalize,\n",
+ "])\n",
+ "transform_224 = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " normalize,\n",
+ "])\n",
+ "\n",
+ "# create heatmap from mask on image\n",
+ "def show_cam_on_image(img, mask):\n",
+ " heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n",
+ " heatmap = np.float32(heatmap) / 255\n",
+ " cam = heatmap + np.float32(img)\n",
+ " cam = cam / np.max(cam)\n",
+ " return cam"
+ ],
+ "metadata": {
+ "id": "uqDTsTS2k8pl"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from pydrive.auth import GoogleAuth\n",
+ "from pydrive.drive import GoogleDrive\n",
+ "from google.colab import auth\n",
+ "from oauth2client.client import GoogleCredentials\n",
+ "\n",
+ "# Authenticate and create the PyDrive client.\n",
+ "auth.authenticate_user()\n",
+ "gauth = GoogleAuth()\n",
+ "gauth.credentials = GoogleCredentials.get_application_default()\n",
+ "drive = GoogleDrive(gauth)\n",
+ "\n",
+ "# downloads weights\n",
+ "ids = ['1jbWiuBrL4sKpAjG3x4oGbs3WOC2UdbIb', '1DHKX_s8rVCDiX4pwnuCCZdGWsOl4SFMn', '1vDmuvbdLbYVAqWz6yVM4vT1Wdzt8KV-g']\n",
+ "for file_id in ids:\n",
+ " downloaded = drive.CreateFile({'id':file_id})\n",
+ " downloaded.FetchMetadata(fetch_all=True)\n",
+ " downloaded.GetContentFile(downloaded.metadata['title'])"
+ ],
+ "metadata": {
+ "id": "eImo3TAenFbo"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model_name = 'ar_base' #@param ['ar_base','vit_base', 'deit_base']\n",
+ "\n",
+ "if model_name == 'ar_base':\n",
+ " from ViT.ViT_new import vit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('ar_base.tar')\n",
+ "\n",
+ "if model_name == 'vit_base':\n",
+ " from ViT.ViT import vit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('vit_base.tar')\n",
+ "\n",
+ "if model_name == 'deit_base':\n",
+ " from ViT.ViT import deit_base_patch16_224 as vit\n",
+ "\n",
+ " # initialize ViT pretrained\n",
+ " model = vit(pretrained=True).cuda()\n",
+ " model.eval()\n",
+ "\n",
+ " model_finetuned = vit().cuda()\n",
+ " checkpoint = torch.load('deit_base.tar')\n",
+ "\n",
+ "model_finetuned.load_state_dict(checkpoint['state_dict'])\n",
+ "model_finetuned.eval()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "cellView": "form",
+ "id": "tayM9OIWlLOT",
+ "outputId": "ac6f819a-ef56-4bbf-c25c-5f9cc9b2bb5c"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "VisionTransformer(\n",
+ " (patch_embed): PatchEmbed(\n",
+ " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
+ " (norm): Identity()\n",
+ " )\n",
+ " (pos_drop): Dropout(p=0.0, inplace=False)\n",
+ " (blocks): ModuleList(\n",
+ " (0): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU()\n",
+ " (drop1): Dropout(p=0.0, inplace=False)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop2): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (pre_logits): Identity()\n",
+ " (head): Linear(in_features=768, out_features=1000, bias=True)\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "start_layer = 0\n",
+ "\n",
+ "# rule 5 from paper\n",
+ "def avg_heads(cam, grad):\n",
+ " cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])\n",
+ " grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])\n",
+ " cam = grad * cam\n",
+ " cam = cam.clamp(min=0).mean(dim=0)\n",
+ " return cam\n",
+ "\n",
+ "# rule 6 from paper\n",
+ "def apply_self_attention_rules(R_ss, cam_ss):\n",
+ " R_ss_addition = torch.matmul(cam_ss, R_ss)\n",
+ " return R_ss_addition\n",
+ "\n",
+ "def generate_relevance(model, input, index=None):\n",
+ " output = model(input, register_hook=True)\n",
+ " if index == None:\n",
+ " index = np.argmax(output.cpu().data.numpy(), axis=-1)\n",
+ "\n",
+ " one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)\n",
+ " one_hot[0, index] = 1\n",
+ " one_hot_vector = one_hot\n",
+ " one_hot = torch.from_numpy(one_hot).requires_grad_(True)\n",
+ " one_hot = torch.sum(one_hot.cuda() * output)\n",
+ " model.zero_grad()\n",
+ " one_hot.backward(retain_graph=True)\n",
+ "\n",
+ " num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]\n",
+ " R = torch.eye(num_tokens, num_tokens).cuda()\n",
+ " for i,blk in enumerate(model.blocks):\n",
+ " if i < start_layer:\n",
+ " continue\n",
+ " grad = blk.attn.get_attn_gradients()\n",
+ " cam = blk.attn.get_attention_map()\n",
+ " cam = avg_heads(cam, grad)\n",
+ " R += apply_self_attention_rules(R.cuda(), cam.cuda())\n",
+ " return R[0, 1:]"
+ ],
+ "metadata": {
+ "id": "tKp64OSWlC7w"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def generate_visualization(model, original_image, class_index=None):\n",
+ " with torch.enable_grad():\n",
+ " transformer_attribution = generate_relevance(model, original_image.unsqueeze(0).cuda(), index=class_index).detach()\n",
+ " transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)\n",
+ " transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')\n",
+ " transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()\n",
+ " transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())\n",
+ " \n",
+ " image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()\n",
+ " image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())\n",
+ " vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)\n",
+ " vis = np.uint8(255 * vis)\n",
+ " vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)\n",
+ " return vis\n",
+ "\n",
+ "def print_top_classes(predictions, **kwargs): \n",
+ " # Print Top-5 predictions\n",
+ " prob = torch.softmax(predictions, dim=1)\n",
+ " class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()\n",
+ " max_str_len = 0\n",
+ " class_names = []\n",
+ " for cls_idx in class_indices:\n",
+ " class_names.append(CLS2IDX[cls_idx])\n",
+ " if len(CLS2IDX[cls_idx]) > max_str_len:\n",
+ " max_str_len = len(CLS2IDX[cls_idx])\n",
+ " \n",
+ " print('Top 5 classes:')\n",
+ " for cls_idx in class_indices:\n",
+ " output_string = '\\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])\n",
+ " output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\\t\\t'\n",
+ " output_string += 'value = {:.3f}\\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])\n",
+ " print(output_string)"
+ ],
+ "metadata": {
+ "id": "rmQ9pacLoGze"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# ImageNet-A"
+ ],
+ "metadata": {
+ "id": "5o-p6euNMibE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with torch.no_grad():\n",
+ " image = Image.open(f'samples/{model_name}/a.png')\n",
+ " dog_cat_image = transform_224(image)\n",
+ "\n",
+ " fig, axs = plt.subplots(1, 2)\n",
+ " fig.set_size_inches(10, 7)\n",
+ " axs[0].imshow(image);\n",
+ " axs[0].axis('off');\n",
+ "\n",
+ " output = model(dog_cat_image.unsqueeze(0).cuda())\n",
+ " print(\"original model\")\n",
+ " print_top_classes(output)\n",
+ "\n",
+ " out = generate_visualization(model, dog_cat_image)\n",
+ "\n",
+ " fig.suptitle('original model',y=0.8)\n",
+ " axs[1].imshow(out);\n",
+ " axs[1].axis('off');\n",
+ "\n",
+ " fig, axs = plt.subplots(1, 2)\n",
+ " fig.set_size_inches(10, 7)\n",
+ " axs[0].imshow(image);\n",
+ " axs[0].axis('off');\n",
+ " output = model_finetuned(dog_cat_image.unsqueeze(0).cuda())\n",
+ " print(\"finetuned model\")\n",
+ " print_top_classes(output)\n",
+ "\n",
+ " out = generate_visualization(model_finetuned, dog_cat_image)\n",
+ "\n",
+ " fig.suptitle('finetuned model',y=0.8)\n",
+ " axs[1].imshow(out);\n",
+ " axs[1].axis('off');"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 842
+ },
+ "id": "q8hsWi_WMlkb",
+ "outputId": "2a89083c-53b6-4b5e-c7b0-c01d9c90446c"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "original model\n",
+ "Top 5 classes:\n",
+ "\t829 : streetcar, tram, tramcar, trolley, trolley car\t\tvalue = 10.911\t prob = 57.2%\n",
+ "\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 10.221\t prob = 28.7%\n",
+ "\t466 : bullet train, bullet \t\tvalue = 6.897\t prob = 1.0%\n",
+ "\t733 : pole \t\tvalue = 6.878\t prob = 1.0%\n",
+ "\t547 : electric locomotive \t\tvalue = 6.626\t prob = 0.8%\n",
+ "finetuned model\n",
+ "Top 5 classes:\n",
+ "\t847 : tank, army tank, armored combat vehicle, armoured combat vehicle\t\tvalue = 11.573\t prob = 60.1%\n",
+ "\t408 : amphibian, amphibious vehicle \t\tvalue = 10.085\t prob = 13.6%\n",
+ "\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 9.585\t prob = 8.2%\n",
+ "\t829 : streetcar, tram, tramcar, trolley, trolley car \t\tvalue = 9.583\t prob = 8.2%\n",
+ "\t586 : half track \t\tvalue = 7.935\t prob = 1.6%\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "