move model to onnx format
Browse files- examples/AFRICAN OYESTER CATCHER.jpg +0 -0
- examples/HOUSE SPARROW.jpg +0 -0
- examples/PEACOCK.jpg +0 -0
- examples/VERMILION FLYCATCHER.jpg +0 -0
- examples/WILD TURKEY.jpg +0 -0
- model/birds_name_mapping.json +452 -0
- notebooks/onnx-testing.ipynb +0 -0
- notebooks/pytorch-birds-resnet34.ipynb +0 -0
- notebooks/torch-to-onnx.ipynb +389 -0
examples/AFRICAN OYESTER CATCHER.jpg
ADDED
![]() |
examples/HOUSE SPARROW.jpg
ADDED
![]() |
examples/PEACOCK.jpg
ADDED
![]() |
examples/VERMILION FLYCATCHER.jpg
ADDED
![]() |
examples/WILD TURKEY.jpg
ADDED
![]() |
model/birds_name_mapping.json
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"0": "ABBOTTS BABBLER",
|
3 |
+
"1": "ABBOTTS BOOBY",
|
4 |
+
"2": "ABYSSINIAN GROUND HORNBILL",
|
5 |
+
"3": "AFRICAN CROWNED CRANE",
|
6 |
+
"4": "AFRICAN EMERALD CUCKOO",
|
7 |
+
"5": "AFRICAN FIREFINCH",
|
8 |
+
"6": "AFRICAN OYSTER CATCHER",
|
9 |
+
"7": "AFRICAN PIED HORNBILL",
|
10 |
+
"8": "ALBATROSS",
|
11 |
+
"9": "ALBERTS TOWHEE",
|
12 |
+
"10": "ALEXANDRINE PARAKEET",
|
13 |
+
"11": "ALPINE CHOUGH",
|
14 |
+
"12": "ALTAMIRA YELLOWTHROAT",
|
15 |
+
"13": "AMERICAN AVOCET",
|
16 |
+
"14": "AMERICAN BITTERN",
|
17 |
+
"15": "AMERICAN COOT",
|
18 |
+
"16": "AMERICAN FLAMINGO",
|
19 |
+
"17": "AMERICAN GOLDFINCH",
|
20 |
+
"18": "AMERICAN KESTREL",
|
21 |
+
"19": "AMERICAN PIPIT",
|
22 |
+
"20": "AMERICAN REDSTART",
|
23 |
+
"21": "AMERICAN WIGEON",
|
24 |
+
"22": "AMETHYST WOODSTAR",
|
25 |
+
"23": "ANDEAN GOOSE",
|
26 |
+
"24": "ANDEAN LAPWING",
|
27 |
+
"25": "ANDEAN SISKIN",
|
28 |
+
"26": "ANHINGA",
|
29 |
+
"27": "ANIANIAU",
|
30 |
+
"28": "ANNAS HUMMINGBIRD",
|
31 |
+
"29": "ANTBIRD",
|
32 |
+
"30": "ANTILLEAN EUPHONIA",
|
33 |
+
"31": "APAPANE",
|
34 |
+
"32": "APOSTLEBIRD",
|
35 |
+
"33": "ARARIPE MANAKIN",
|
36 |
+
"34": "ASHY STORM PETREL",
|
37 |
+
"35": "ASHY THRUSHBIRD",
|
38 |
+
"36": "ASIAN CRESTED IBIS",
|
39 |
+
"37": "ASIAN DOLLARD BIRD",
|
40 |
+
"38": "AUCKLAND SHAQ",
|
41 |
+
"39": "AUSTRAL CANASTERO",
|
42 |
+
"40": "AUSTRALASIAN FIGBIRD",
|
43 |
+
"41": "AVADAVAT",
|
44 |
+
"42": "AZARAS SPINETAIL",
|
45 |
+
"43": "AZURE BREASTED PITTA",
|
46 |
+
"44": "AZURE JAY",
|
47 |
+
"45": "AZURE TANAGER",
|
48 |
+
"46": "AZURE TIT",
|
49 |
+
"47": "BAIKAL TEAL",
|
50 |
+
"48": "BALD EAGLE",
|
51 |
+
"49": "BALD IBIS",
|
52 |
+
"50": "BALI STARLING",
|
53 |
+
"51": "BALTIMORE ORIOLE",
|
54 |
+
"52": "BANANAQUIT",
|
55 |
+
"53": "BAND TAILED GUAN",
|
56 |
+
"54": "BANDED BROADBILL",
|
57 |
+
"55": "BANDED PITA",
|
58 |
+
"56": "BANDED STILT",
|
59 |
+
"57": "BAR-TAILED GODWIT",
|
60 |
+
"58": "BARN OWL",
|
61 |
+
"59": "BARN SWALLOW",
|
62 |
+
"60": "BARRED PUFFBIRD",
|
63 |
+
"61": "BARROWS GOLDENEYE",
|
64 |
+
"62": "BAY-BREASTED WARBLER",
|
65 |
+
"63": "BEARDED BARBET",
|
66 |
+
"64": "BEARDED BELLBIRD",
|
67 |
+
"65": "BEARDED REEDLING",
|
68 |
+
"66": "BELTED KINGFISHER",
|
69 |
+
"67": "BIRD OF PARADISE",
|
70 |
+
"68": "BLACK & YELLOW BROADBILL",
|
71 |
+
"69": "BLACK BAZA",
|
72 |
+
"70": "BLACK COCKATO",
|
73 |
+
"71": "BLACK FRANCOLIN",
|
74 |
+
"72": "BLACK SKIMMER",
|
75 |
+
"73": "BLACK SWAN",
|
76 |
+
"74": "BLACK TAIL CRAKE",
|
77 |
+
"75": "BLACK THROATED BUSHTIT",
|
78 |
+
"76": "BLACK THROATED WARBLER",
|
79 |
+
"77": "BLACK VENTED SHEARWATER",
|
80 |
+
"78": "BLACK VULTURE",
|
81 |
+
"79": "BLACK-CAPPED CHICKADEE",
|
82 |
+
"80": "BLACK-NECKED GREBE",
|
83 |
+
"81": "BLACK-THROATED SPARROW",
|
84 |
+
"82": "BLACKBURNIAM WARBLER",
|
85 |
+
"83": "BLONDE CRESTED WOODPECKER",
|
86 |
+
"84": "BLOOD PHEASANT",
|
87 |
+
"85": "BLUE COAU",
|
88 |
+
"86": "BLUE DACNIS",
|
89 |
+
"87": "BLUE GROUSE",
|
90 |
+
"88": "BLUE HERON",
|
91 |
+
"89": "BLUE MALKOHA",
|
92 |
+
"90": "BLUE THROATED TOUCANET",
|
93 |
+
"91": "BOBOLINK",
|
94 |
+
"92": "BORNEAN BRISTLEHEAD",
|
95 |
+
"93": "BORNEAN LEAFBIRD",
|
96 |
+
"94": "BORNEAN PHEASANT",
|
97 |
+
"95": "BRANDT CORMARANT",
|
98 |
+
"96": "BREWERS BLACKBIRD",
|
99 |
+
"97": "BROWN CREPPER",
|
100 |
+
"98": "BROWN NOODY",
|
101 |
+
"99": "BROWN THRASHER",
|
102 |
+
"100": "BUFFLEHEAD",
|
103 |
+
"101": "BULWERS PHEASANT",
|
104 |
+
"102": "BURCHELLS COURSER",
|
105 |
+
"103": "BUSH TURKEY",
|
106 |
+
"104": "CAATINGA CACHOLOTE",
|
107 |
+
"105": "CACTUS WREN",
|
108 |
+
"106": "CALIFORNIA CONDOR",
|
109 |
+
"107": "CALIFORNIA GULL",
|
110 |
+
"108": "CALIFORNIA QUAIL",
|
111 |
+
"109": "CAMPO FLICKER",
|
112 |
+
"110": "CANARY",
|
113 |
+
"111": "CAPE GLOSSY STARLING",
|
114 |
+
"112": "CAPE LONGCLAW",
|
115 |
+
"113": "CAPE MAY WARBLER",
|
116 |
+
"114": "CAPE ROCK THRUSH",
|
117 |
+
"115": "CAPPED HERON",
|
118 |
+
"116": "CAPUCHINBIRD",
|
119 |
+
"117": "CARMINE BEE-EATER",
|
120 |
+
"118": "CASPIAN TERN",
|
121 |
+
"119": "CASSOWARY",
|
122 |
+
"120": "CEDAR WAXWING",
|
123 |
+
"121": "CERULEAN WARBLER",
|
124 |
+
"122": "CHARA DE COLLAR",
|
125 |
+
"123": "CHATTERING LORY",
|
126 |
+
"124": "CHESTNET BELLIED EUPHONIA",
|
127 |
+
"125": "CHINESE BAMBOO PARTRIDGE",
|
128 |
+
"126": "CHINESE POND HERON",
|
129 |
+
"127": "CHIPPING SPARROW",
|
130 |
+
"128": "CHUCAO TAPACULO",
|
131 |
+
"129": "CHUKAR PARTRIDGE",
|
132 |
+
"130": "CINNAMON ATTILA",
|
133 |
+
"131": "CINNAMON FLYCATCHER",
|
134 |
+
"132": "CINNAMON TEAL",
|
135 |
+
"133": "CLARKS NUTCRACKER",
|
136 |
+
"134": "COCK OF THE ROCK",
|
137 |
+
"135": "COCKATOO",
|
138 |
+
"136": "COLLARED ARACARI",
|
139 |
+
"137": "COMMON FIRECREST",
|
140 |
+
"138": "COMMON GRACKLE",
|
141 |
+
"139": "COMMON HOUSE MARTIN",
|
142 |
+
"140": "COMMON IORA",
|
143 |
+
"141": "COMMON LOON",
|
144 |
+
"142": "COMMON POORWILL",
|
145 |
+
"143": "COMMON STARLING",
|
146 |
+
"144": "COPPERY TAILED COUCAL",
|
147 |
+
"145": "CRAB PLOVER",
|
148 |
+
"146": "CRANE HAWK",
|
149 |
+
"147": "CREAM COLORED WOODPECKER",
|
150 |
+
"148": "CRESTED AUKLET",
|
151 |
+
"149": "CRESTED CARACARA",
|
152 |
+
"150": "CRESTED COUA",
|
153 |
+
"151": "CRESTED FIREBACK",
|
154 |
+
"152": "CRESTED KINGFISHER",
|
155 |
+
"153": "CRESTED NUTHATCH",
|
156 |
+
"154": "CRESTED OROPENDOLA",
|
157 |
+
"155": "CRESTED SHRIKETIT",
|
158 |
+
"156": "CRIMSON CHAT",
|
159 |
+
"157": "CRIMSON SUNBIRD",
|
160 |
+
"158": "CROW",
|
161 |
+
"159": "CROWNED PIGEON",
|
162 |
+
"160": "CUBAN TODY",
|
163 |
+
"161": "CUBAN TROGON",
|
164 |
+
"162": "CURL CRESTED ARACURI",
|
165 |
+
"163": "D-ARNAUDS BARBET",
|
166 |
+
"164": "DALMATIAN PELICAN",
|
167 |
+
"165": "DARJEELING WOODPECKER",
|
168 |
+
"166": "DARK EYED JUNCO",
|
169 |
+
"167": "DARWINS FLYCATCHER",
|
170 |
+
"168": "DAURIAN REDSTART",
|
171 |
+
"169": "DEMOISELLE CRANE",
|
172 |
+
"170": "DOUBLE BARRED FINCH",
|
173 |
+
"171": "DOUBLE BRESTED CORMARANT",
|
174 |
+
"172": "DOUBLE EYED FIG PARROT",
|
175 |
+
"173": "DOWNY WOODPECKER",
|
176 |
+
"174": "DUSKY LORY",
|
177 |
+
"175": "DUSKY ROBIN",
|
178 |
+
"176": "EARED PITA",
|
179 |
+
"177": "EASTERN BLUEBIRD",
|
180 |
+
"178": "EASTERN BLUEBONNET",
|
181 |
+
"179": "EASTERN GOLDEN WEAVER",
|
182 |
+
"180": "EASTERN MEADOWLARK",
|
183 |
+
"181": "EASTERN ROSELLA",
|
184 |
+
"182": "EASTERN TOWEE",
|
185 |
+
"183": "EASTERN WIP POOR WILL",
|
186 |
+
"184": "ECUADORIAN HILLSTAR",
|
187 |
+
"185": "EGYPTIAN GOOSE",
|
188 |
+
"186": "ELEGANT TROGON",
|
189 |
+
"187": "ELLIOTS PHEASANT",
|
190 |
+
"188": "EMERALD TANAGER",
|
191 |
+
"189": "EMPEROR PENGUIN",
|
192 |
+
"190": "EMU",
|
193 |
+
"191": "ENGGANO MYNA",
|
194 |
+
"192": "EURASIAN BULLFINCH",
|
195 |
+
"193": "EURASIAN GOLDEN ORIOLE",
|
196 |
+
"194": "EURASIAN MAGPIE",
|
197 |
+
"195": "EUROPEAN GOLDFINCH",
|
198 |
+
"196": "EUROPEAN TURTLE DOVE",
|
199 |
+
"197": "EVENING GROSBEAK",
|
200 |
+
"198": "FAIRY BLUEBIRD",
|
201 |
+
"199": "FAIRY PENGUIN",
|
202 |
+
"200": "FAIRY TERN",
|
203 |
+
"201": "FAN TAILED WIDOW",
|
204 |
+
"202": "FASCIATED WREN",
|
205 |
+
"203": "FIERY MINIVET",
|
206 |
+
"204": "FIORDLAND PENGUIN",
|
207 |
+
"205": "FIRE TAILLED MYZORNIS",
|
208 |
+
"206": "FLAME BOWERBIRD",
|
209 |
+
"207": "FLAME TANAGER",
|
210 |
+
"208": "FRIGATE",
|
211 |
+
"209": "GAMBELS QUAIL",
|
212 |
+
"210": "GANG GANG COCKATOO",
|
213 |
+
"211": "GILA WOODPECKER",
|
214 |
+
"212": "GILDED FLICKER",
|
215 |
+
"213": "GLOSSY IBIS",
|
216 |
+
"214": "GO AWAY BIRD",
|
217 |
+
"215": "GOLD WING WARBLER",
|
218 |
+
"216": "GOLDEN BOWER BIRD",
|
219 |
+
"217": "GOLDEN CHEEKED WARBLER",
|
220 |
+
"218": "GOLDEN CHLOROPHONIA",
|
221 |
+
"219": "GOLDEN EAGLE",
|
222 |
+
"220": "GOLDEN PARAKEET",
|
223 |
+
"221": "GOLDEN PHEASANT",
|
224 |
+
"222": "GOLDEN PIPIT",
|
225 |
+
"223": "GOULDIAN FINCH",
|
226 |
+
"224": "GRANDALA",
|
227 |
+
"225": "GRAY CATBIRD",
|
228 |
+
"226": "GRAY KINGBIRD",
|
229 |
+
"227": "GRAY PARTRIDGE",
|
230 |
+
"228": "GREAT GRAY OWL",
|
231 |
+
"229": "GREAT JACAMAR",
|
232 |
+
"230": "GREAT KISKADEE",
|
233 |
+
"231": "GREAT POTOO",
|
234 |
+
"232": "GREAT TINAMOU",
|
235 |
+
"233": "GREAT XENOPS",
|
236 |
+
"234": "GREATER PEWEE",
|
237 |
+
"235": "GREATOR SAGE GROUSE",
|
238 |
+
"236": "GREEN BROADBILL",
|
239 |
+
"237": "GREEN JAY",
|
240 |
+
"238": "GREEN MAGPIE",
|
241 |
+
"239": "GREY CUCKOOSHRIKE",
|
242 |
+
"240": "GREY PLOVER",
|
243 |
+
"241": "GROVED BILLED ANI",
|
244 |
+
"242": "GUINEA TURACO",
|
245 |
+
"243": "GUINEAFOWL",
|
246 |
+
"244": "GURNEYS PITTA",
|
247 |
+
"245": "GYRFALCON",
|
248 |
+
"246": "HAMERKOP",
|
249 |
+
"247": "HARLEQUIN DUCK",
|
250 |
+
"248": "HARLEQUIN QUAIL",
|
251 |
+
"249": "HARPY EAGLE",
|
252 |
+
"250": "HAWAIIAN GOOSE",
|
253 |
+
"251": "HAWFINCH",
|
254 |
+
"252": "HELMET VANGA",
|
255 |
+
"253": "HEPATIC TANAGER",
|
256 |
+
"254": "HIMALAYAN BLUETAIL",
|
257 |
+
"255": "HIMALAYAN MONAL",
|
258 |
+
"256": "HOATZIN",
|
259 |
+
"257": "HOODED MERGANSER",
|
260 |
+
"258": "HOOPOES",
|
261 |
+
"259": "HORNED GUAN",
|
262 |
+
"260": "HORNED LARK",
|
263 |
+
"261": "HORNED SUNGEM",
|
264 |
+
"262": "HOUSE FINCH",
|
265 |
+
"263": "HOUSE SPARROW",
|
266 |
+
"264": "HYACINTH MACAW",
|
267 |
+
"265": "IBERIAN MAGPIE",
|
268 |
+
"266": "IBISBILL",
|
269 |
+
"267": "IMPERIAL SHAQ",
|
270 |
+
"268": "INCA TERN",
|
271 |
+
"269": "INDIAN BUSTARD",
|
272 |
+
"270": "INDIAN PITTA",
|
273 |
+
"271": "INDIAN ROLLER",
|
274 |
+
"272": "INDIAN VULTURE",
|
275 |
+
"273": "INDIGO BUNTING",
|
276 |
+
"274": "INDIGO FLYCATCHER",
|
277 |
+
"275": "INLAND DOTTEREL",
|
278 |
+
"276": "IVORY BILLED ARACARI",
|
279 |
+
"277": "IVORY GULL",
|
280 |
+
"278": "IWI",
|
281 |
+
"279": "JABIRU",
|
282 |
+
"280": "JACK SNIPE",
|
283 |
+
"281": "JANDAYA PARAKEET",
|
284 |
+
"282": "JAPANESE ROBIN",
|
285 |
+
"283": "JAVA SPARROW",
|
286 |
+
"284": "JOCOTOCO ANTPITTA",
|
287 |
+
"285": "KAGU",
|
288 |
+
"286": "KAKAPO",
|
289 |
+
"287": "KILLDEAR",
|
290 |
+
"288": "KING EIDER",
|
291 |
+
"289": "KING VULTURE",
|
292 |
+
"290": "KIWI",
|
293 |
+
"291": "KOOKABURRA",
|
294 |
+
"292": "LARK BUNTING",
|
295 |
+
"293": "LAZULI BUNTING",
|
296 |
+
"294": "LESSER ADJUTANT",
|
297 |
+
"295": "LILAC ROLLER",
|
298 |
+
"296": "LITTLE AUK",
|
299 |
+
"297": "LOGGERHEAD SHRIKE",
|
300 |
+
"298": "LONG-EARED OWL",
|
301 |
+
"299": "MAGPIE GOOSE",
|
302 |
+
"300": "MALABAR HORNBILL",
|
303 |
+
"301": "MALACHITE KINGFISHER",
|
304 |
+
"302": "MALAGASY WHITE EYE",
|
305 |
+
"303": "MALEO",
|
306 |
+
"304": "MALLARD DUCK",
|
307 |
+
"305": "MANDRIN DUCK",
|
308 |
+
"306": "MANGROVE CUCKOO",
|
309 |
+
"307": "MARABOU STORK",
|
310 |
+
"308": "MASKED BOOBY",
|
311 |
+
"309": "MASKED LAPWING",
|
312 |
+
"310": "MCKAYS BUNTING",
|
313 |
+
"311": "MIKADO PHEASANT",
|
314 |
+
"312": "MOURNING DOVE",
|
315 |
+
"313": "MYNA",
|
316 |
+
"314": "NICOBAR PIGEON",
|
317 |
+
"315": "NOISY FRIARBIRD",
|
318 |
+
"316": "NORTHERN BEARDLESS TYRANNULET",
|
319 |
+
"317": "NORTHERN CARDINAL",
|
320 |
+
"318": "NORTHERN FLICKER",
|
321 |
+
"319": "NORTHERN FULMAR",
|
322 |
+
"320": "NORTHERN GANNET",
|
323 |
+
"321": "NORTHERN GOSHAWK",
|
324 |
+
"322": "NORTHERN JACANA",
|
325 |
+
"323": "NORTHERN MOCKINGBIRD",
|
326 |
+
"324": "NORTHERN PARULA",
|
327 |
+
"325": "NORTHERN RED BISHOP",
|
328 |
+
"326": "NORTHERN SHOVELER",
|
329 |
+
"327": "OCELLATED TURKEY",
|
330 |
+
"328": "OKINAWA RAIL",
|
331 |
+
"329": "ORANGE BRESTED BUNTING",
|
332 |
+
"330": "ORIENTAL BAY OWL",
|
333 |
+
"331": "OSPREY",
|
334 |
+
"332": "OSTRICH",
|
335 |
+
"333": "OVENBIRD",
|
336 |
+
"334": "OYSTER CATCHER",
|
337 |
+
"335": "PAINTED BUNTING",
|
338 |
+
"336": "PALILA",
|
339 |
+
"337": "PARADISE TANAGER",
|
340 |
+
"338": "PARAKETT AKULET",
|
341 |
+
"339": "PARUS MAJOR",
|
342 |
+
"340": "PATAGONIAN SIERRA FINCH",
|
343 |
+
"341": "PEACOCK",
|
344 |
+
"342": "PEREGRINE FALCON",
|
345 |
+
"343": "PHILIPPINE EAGLE",
|
346 |
+
"344": "PINK ROBIN",
|
347 |
+
"345": "POMARINE JAEGER",
|
348 |
+
"346": "PUFFIN",
|
349 |
+
"347": "PURPLE FINCH",
|
350 |
+
"348": "PURPLE GALLINULE",
|
351 |
+
"349": "PURPLE MARTIN",
|
352 |
+
"350": "PURPLE SWAMPHEN",
|
353 |
+
"351": "PYGMY KINGFISHER",
|
354 |
+
"352": "QUETZAL",
|
355 |
+
"353": "RAINBOW LORIKEET",
|
356 |
+
"354": "RAZORBILL",
|
357 |
+
"355": "RED BEARDED BEE EATER",
|
358 |
+
"356": "RED BELLIED PITTA",
|
359 |
+
"357": "RED BROWED FINCH",
|
360 |
+
"358": "RED FACED CORMORANT",
|
361 |
+
"359": "RED FACED WARBLER",
|
362 |
+
"360": "RED FODY",
|
363 |
+
"361": "RED HEADED DUCK",
|
364 |
+
"362": "RED HEADED WOODPECKER",
|
365 |
+
"363": "RED HONEY CREEPER",
|
366 |
+
"364": "RED NAPED TROGON",
|
367 |
+
"365": "RED TAILED HAWK",
|
368 |
+
"366": "RED TAILED THRUSH",
|
369 |
+
"367": "RED WINGED BLACKBIRD",
|
370 |
+
"368": "RED WISKERED BULBUL",
|
371 |
+
"369": "REGENT BOWERBIRD",
|
372 |
+
"370": "RING-NECKED PHEASANT",
|
373 |
+
"371": "ROADRUNNER",
|
374 |
+
"372": "ROBIN",
|
375 |
+
"373": "ROCK DOVE",
|
376 |
+
"374": "ROSY FACED LOVEBIRD",
|
377 |
+
"375": "ROUGH LEG BUZZARD",
|
378 |
+
"376": "ROYAL FLYCATCHER",
|
379 |
+
"377": "RUBY THROATED HUMMINGBIRD",
|
380 |
+
"378": "RUDY KINGFISHER",
|
381 |
+
"379": "RUFOUS KINGFISHER",
|
382 |
+
"380": "RUFUOS MOTMOT",
|
383 |
+
"381": "SAMATRAN THRUSH",
|
384 |
+
"382": "SAND MARTIN",
|
385 |
+
"383": "SANDHILL CRANE",
|
386 |
+
"384": "SATYR TRAGOPAN",
|
387 |
+
"385": "SCARLET CROWNED FRUIT DOVE",
|
388 |
+
"386": "SCARLET IBIS",
|
389 |
+
"387": "SCARLET MACAW",
|
390 |
+
"388": "SCARLET TANAGER",
|
391 |
+
"389": "SHOEBILL",
|
392 |
+
"390": "SHORT BILLED DOWITCHER",
|
393 |
+
"391": "SKUA",
|
394 |
+
"392": "SMITHS LONGSPUR",
|
395 |
+
"393": "SNOWY EGRET",
|
396 |
+
"394": "SNOWY OWL",
|
397 |
+
"395": "SNOWY PLOVER",
|
398 |
+
"396": "SORA",
|
399 |
+
"397": "SPANGLED COTINGA",
|
400 |
+
"398": "SPLENDID WREN",
|
401 |
+
"399": "SPOON BILED SANDPIPER",
|
402 |
+
"400": "SPOONBILL",
|
403 |
+
"401": "SPOTTED CATBIRD",
|
404 |
+
"402": "SRI LANKA BLUE MAGPIE",
|
405 |
+
"403": "STEAMER DUCK",
|
406 |
+
"404": "STORK BILLED KINGFISHER",
|
407 |
+
"405": "STRAWBERRY FINCH",
|
408 |
+
"406": "STRIPED OWL",
|
409 |
+
"407": "STRIPPED MANAKIN",
|
410 |
+
"408": "STRIPPED SWALLOW",
|
411 |
+
"409": "SUPERB STARLING",
|
412 |
+
"410": "SWINHOES PHEASANT",
|
413 |
+
"411": "TAILORBIRD",
|
414 |
+
"412": "TAIWAN MAGPIE",
|
415 |
+
"413": "TAKAHE",
|
416 |
+
"414": "TASMANIAN HEN",
|
417 |
+
"415": "TEAL DUCK",
|
418 |
+
"416": "TIT MOUSE",
|
419 |
+
"417": "TOUCHAN",
|
420 |
+
"418": "TOWNSENDS WARBLER",
|
421 |
+
"419": "TREE SWALLOW",
|
422 |
+
"420": "TRICOLORED BLACKBIRD",
|
423 |
+
"421": "TROPICAL KINGBIRD",
|
424 |
+
"422": "TRUMPTER SWAN",
|
425 |
+
"423": "TURKEY VULTURE",
|
426 |
+
"424": "TURQUOISE MOTMOT",
|
427 |
+
"425": "UMBRELLA BIRD",
|
428 |
+
"426": "VARIED THRUSH",
|
429 |
+
"427": "VEERY",
|
430 |
+
"428": "VENEZUELIAN TROUPIAL",
|
431 |
+
"429": "VERMILION FLYCATHER",
|
432 |
+
"430": "VICTORIA CROWNED PIGEON",
|
433 |
+
"431": "VIOLET GREEN SWALLOW",
|
434 |
+
"432": "VIOLET TURACO",
|
435 |
+
"433": "VULTURINE GUINEAFOWL",
|
436 |
+
"434": "WALL CREAPER",
|
437 |
+
"435": "WATTLED CURASSOW",
|
438 |
+
"436": "WATTLED LAPWING",
|
439 |
+
"437": "WHIMBREL",
|
440 |
+
"438": "WHITE BROWED CRAKE",
|
441 |
+
"439": "WHITE CHEEKED TURACO",
|
442 |
+
"440": "WHITE CRESTED HORNBILL",
|
443 |
+
"441": "WHITE NECKED RAVEN",
|
444 |
+
"442": "WHITE TAILED TROPIC",
|
445 |
+
"443": "WHITE THROATED BEE EATER",
|
446 |
+
"444": "WILD TURKEY",
|
447 |
+
"445": "WILSONS BIRD OF PARADISE",
|
448 |
+
"446": "WOOD DUCK",
|
449 |
+
"447": "YELLOW BELLIED FLOWERPECKER",
|
450 |
+
"448": "YELLOW CACIQUE",
|
451 |
+
"449": "YELLOW HEADED BLACKBIRD"
|
452 |
+
}
|
notebooks/onnx-testing.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/pytorch-birds-resnet34.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/torch-to-onnx.ipynb
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Converting PyTorch to ONNX"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stderr",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"/home/gautham/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
20 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
21 |
+
]
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"source": [
|
25 |
+
"import torch\n",
|
26 |
+
"from torch import nn\n",
|
27 |
+
"from torch.nn import functional as F"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"metadata": {},
|
33 |
+
"source": [
|
34 |
+
"## Defining the model"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 2,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"class DoubleConv(nn.Module):\n",
|
44 |
+
" def __init__(self, in_channels, out_channels):\n",
|
45 |
+
" super().__init__()\n",
|
46 |
+
" self.conv = nn.Sequential(\n",
|
47 |
+
" nn.Conv2d(in_channels, out_channels, 3, padding=1),\n",
|
48 |
+
" nn.BatchNorm2d(out_channels),\n",
|
49 |
+
" nn.ReLU(inplace=True),\n",
|
50 |
+
" nn.Conv2d(out_channels, out_channels, 3, padding=1),\n",
|
51 |
+
" nn.BatchNorm2d(out_channels),\n",
|
52 |
+
" nn.ReLU(inplace=True)\n",
|
53 |
+
" )\n",
|
54 |
+
" def forward(self, x):\n",
|
55 |
+
" return self.conv(x)\n",
|
56 |
+
"\n",
|
57 |
+
"class Down(nn.Module):\n",
|
58 |
+
" def __init__(self, in_channels, out_channels):\n",
|
59 |
+
" super().__init__()\n",
|
60 |
+
" self.down = nn.Sequential(\n",
|
61 |
+
" nn.MaxPool2d(2),\n",
|
62 |
+
" DoubleConv(in_channels, out_channels)\n",
|
63 |
+
" )\n",
|
64 |
+
" def forward(self, x):\n",
|
65 |
+
" return self.down(x)\n",
|
66 |
+
"\n",
|
67 |
+
"class Up(nn.Module):\n",
|
68 |
+
" def __init__(self, in_channels, out_channels, bilinear=False):\n",
|
69 |
+
" super().__init__()\n",
|
70 |
+
" if bilinear:\n",
|
71 |
+
" self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),\n",
|
72 |
+
" nn.Conv2d(in_channels, in_channels // 2, 1))\n",
|
73 |
+
" else:\n",
|
74 |
+
" self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)\n",
|
75 |
+
" \n",
|
76 |
+
" self.conv = DoubleConv(in_channels, out_channels)\n",
|
77 |
+
"\n",
|
78 |
+
" def forward(self, x1, x2):\n",
|
79 |
+
" x1 = self.up(x1)\n",
|
80 |
+
" diffY = x2.size()[2] - x1.size()[2]\n",
|
81 |
+
" diffX = x2.size()[3] - x1.size()[3]\n",
|
82 |
+
" x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])\n",
|
83 |
+
" x = torch.cat([x2, x1], dim=1)\n",
|
84 |
+
" return self.conv(x)\n",
|
85 |
+
"\n",
|
86 |
+
"class OutConv(nn.Module):\n",
|
87 |
+
" def __init__(self, in_channels, out_channels):\n",
|
88 |
+
" super().__init__()\n",
|
89 |
+
" self.conv = nn.Conv2d(in_channels, out_channels, 1)\n",
|
90 |
+
" self.sigmoid = nn.Sigmoid()\n",
|
91 |
+
"\n",
|
92 |
+
" def forward(self, x):\n",
|
93 |
+
" return self.sigmoid(self.conv(x))\n",
|
94 |
+
"\n",
|
95 |
+
"class UNet(nn.Module):\n",
|
96 |
+
" def __init__(self, n_channels, n_classes):\n",
|
97 |
+
" super().__init__()\n",
|
98 |
+
" self.inc = DoubleConv(n_channels, 64)\n",
|
99 |
+
" self.down1 = Down(64, 128)\n",
|
100 |
+
" self.down2 = Down(128, 256)\n",
|
101 |
+
" self.down3 = Down(256, 512)\n",
|
102 |
+
" self.down4 = Down(512, 1024)\n",
|
103 |
+
" self.up1 = Up(1024, 512)\n",
|
104 |
+
" self.up2 = Up(512, 256)\n",
|
105 |
+
" self.up3 = Up(256, 128)\n",
|
106 |
+
" self.up4 = Up(128, 64)\n",
|
107 |
+
" self.outc = OutConv(64, n_classes)\n",
|
108 |
+
"\n",
|
109 |
+
" def forward(self, x):\n",
|
110 |
+
" x1 = self.inc(x)\n",
|
111 |
+
" x2 = self.down1(x1)\n",
|
112 |
+
" x3 = self.down2(x2)\n",
|
113 |
+
" x4 = self.down3(x3)\n",
|
114 |
+
" x5 = self.down4(x4)\n",
|
115 |
+
" x = self.up1(x5, x4)\n",
|
116 |
+
" x = self.up2(x4, x3)\n",
|
117 |
+
" x = self.up3(x3, x2)\n",
|
118 |
+
" x = self.up4(x2, x1)\n",
|
119 |
+
" logits = self.outc(x)\n",
|
120 |
+
" return logits"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "markdown",
|
125 |
+
"metadata": {},
|
126 |
+
"source": [
|
127 |
+
"## Loading the model"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"execution_count": 3,
|
133 |
+
"metadata": {},
|
134 |
+
"outputs": [
|
135 |
+
{
|
136 |
+
"ename": "RuntimeError",
|
137 |
+
"evalue": "Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.",
|
138 |
+
"output_type": "error",
|
139 |
+
"traceback": [
|
140 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
141 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
142 |
+
"Cell \u001b[0;32mIn [3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[39m=\u001b[39m UNet(n_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, n_classes\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m model\u001b[39m.\u001b[39mload_state_dict(torch\u001b[39m.\u001b[39;49mload(\u001b[39m'\u001b[39;49m\u001b[39m../weights/water_bodies_model.pth\u001b[39;49m\u001b[39m'\u001b[39;49m))\n\u001b[1;32m 3\u001b[0m model\u001b[39m.\u001b[39mto(\u001b[39m'\u001b[39m\u001b[39mcpu\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m 5\u001b[0m dummy_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m1\u001b[39m, \u001b[39m3\u001b[39m, \u001b[39m256\u001b[39m, \u001b[39m256\u001b[39m)\n",
|
143 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:789\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 788\u001b[0m \u001b[39mraise\u001b[39;00m pickle\u001b[39m.\u001b[39mUnpicklingError(UNSAFE_MESSAGE \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(e)) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39m\n\u001b[0;32m--> 789\u001b[0m \u001b[39mreturn\u001b[39;00m _load(opened_zipfile, map_location, pickle_module, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mpickle_load_args)\n\u001b[1;32m 790\u001b[0m \u001b[39mif\u001b[39;00m weights_only:\n\u001b[1;32m 791\u001b[0m \u001b[39mtry\u001b[39;00m:\n",
|
144 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:1131\u001b[0m, in \u001b[0;36m_load\u001b[0;34m(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1129\u001b[0m unpickler \u001b[39m=\u001b[39m UnpicklerWrapper(data_file, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpickle_load_args)\n\u001b[1;32m 1130\u001b[0m unpickler\u001b[39m.\u001b[39mpersistent_load \u001b[39m=\u001b[39m persistent_load\n\u001b[0;32m-> 1131\u001b[0m result \u001b[39m=\u001b[39m unpickler\u001b[39m.\u001b[39;49mload()\n\u001b[1;32m 1133\u001b[0m torch\u001b[39m.\u001b[39m_utils\u001b[39m.\u001b[39m_validate_loaded_sparse_tensors()\n\u001b[1;32m 1135\u001b[0m \u001b[39mreturn\u001b[39;00m result\n",
|
145 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:1101\u001b[0m, in \u001b[0;36m_load.<locals>.persistent_load\u001b[0;34m(saved_id)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[39mif\u001b[39;00m key \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m loaded_storages:\n\u001b[1;32m 1100\u001b[0m nbytes \u001b[39m=\u001b[39m numel \u001b[39m*\u001b[39m torch\u001b[39m.\u001b[39m_utils\u001b[39m.\u001b[39m_element_size(dtype)\n\u001b[0;32m-> 1101\u001b[0m load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))\n\u001b[1;32m 1103\u001b[0m \u001b[39mreturn\u001b[39;00m loaded_storages[key]\n",
|
146 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:1083\u001b[0m, in \u001b[0;36m_load.<locals>.load_tensor\u001b[0;34m(dtype, numel, key, location)\u001b[0m\n\u001b[1;32m 1079\u001b[0m storage \u001b[39m=\u001b[39m zip_file\u001b[39m.\u001b[39mget_storage_from_record(name, numel, torch\u001b[39m.\u001b[39mUntypedStorage)\u001b[39m.\u001b[39mstorage()\u001b[39m.\u001b[39muntyped()\n\u001b[1;32m 1080\u001b[0m \u001b[39m# TODO: Once we decide to break serialization FC, we can\u001b[39;00m\n\u001b[1;32m 1081\u001b[0m \u001b[39m# stop wrapping with TypedStorage\u001b[39;00m\n\u001b[1;32m 1082\u001b[0m loaded_storages[key] \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mstorage\u001b[39m.\u001b[39mTypedStorage(\n\u001b[0;32m-> 1083\u001b[0m wrap_storage\u001b[39m=\u001b[39mrestore_location(storage, location),\n\u001b[1;32m 1084\u001b[0m dtype\u001b[39m=\u001b[39mdtype)\n",
|
147 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:215\u001b[0m, in \u001b[0;36mdefault_restore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdefault_restore_location\u001b[39m(storage, location):\n\u001b[1;32m 214\u001b[0m \u001b[39mfor\u001b[39;00m _, _, fn \u001b[39min\u001b[39;00m _package_registry:\n\u001b[0;32m--> 215\u001b[0m result \u001b[39m=\u001b[39m fn(storage, location)\n\u001b[1;32m 216\u001b[0m \u001b[39mif\u001b[39;00m result \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 217\u001b[0m \u001b[39mreturn\u001b[39;00m result\n",
|
148 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:182\u001b[0m, in \u001b[0;36m_cuda_deserialize\u001b[0;34m(obj, location)\u001b[0m\n\u001b[1;32m 180\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_cuda_deserialize\u001b[39m(obj, location):\n\u001b[1;32m 181\u001b[0m \u001b[39mif\u001b[39;00m location\u001b[39m.\u001b[39mstartswith(\u001b[39m'\u001b[39m\u001b[39mcuda\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[0;32m--> 182\u001b[0m device \u001b[39m=\u001b[39m validate_cuda_device(location)\n\u001b[1;32m 183\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mgetattr\u001b[39m(obj, \u001b[39m\"\u001b[39m\u001b[39m_torch_load_uninitialized\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m 184\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mdevice(device):\n",
|
149 |
+
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/serialization.py:166\u001b[0m, in \u001b[0;36mvalidate_cuda_device\u001b[0;34m(location)\u001b[0m\n\u001b[1;32m 163\u001b[0m device \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39m_utils\u001b[39m.\u001b[39m_get_device_index(location, \u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 165\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mis_available():\n\u001b[0;32m--> 166\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mAttempting to deserialize object on a CUDA \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 167\u001b[0m \u001b[39m'\u001b[39m\u001b[39mdevice but torch.cuda.is_available() is False. \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 168\u001b[0m \u001b[39m'\u001b[39m\u001b[39mIf you are running on a CPU-only machine, \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 169\u001b[0m \u001b[39m'\u001b[39m\u001b[39mplease use torch.load with map_location=torch.device(\u001b[39m\u001b[39m\\'\u001b[39;00m\u001b[39mcpu\u001b[39m\u001b[39m\\'\u001b[39;00m\u001b[39m) \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 170\u001b[0m \u001b[39m'\u001b[39m\u001b[39mto map your storages to the CPU.\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m 171\u001b[0m device_count \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mdevice_count()\n\u001b[1;32m 172\u001b[0m \u001b[39mif\u001b[39;00m device \u001b[39m>\u001b[39m\u001b[39m=\u001b[39m device_count:\n",
|
150 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU."
|
151 |
+
]
|
152 |
+
}
|
153 |
+
],
|
154 |
+
"source": [
|
155 |
+
"model = UNet(n_channels=3, n_classes=1)\n",
|
156 |
+
"model.load_state_dict(torch.load('../weights/water_bodies_model.pth'))\n",
|
157 |
+
"\n",
|
158 |
+
"dummy_input = torch.randn(1, 3, 256, 256)\n",
|
159 |
+
"\n",
|
160 |
+
"model.eval()\n",
|
161 |
+
"torch_out = model(dummy_input)"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "markdown",
|
166 |
+
"metadata": {},
|
167 |
+
"source": [
|
168 |
+
"## Converting to ONNX"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [
|
176 |
+
{
|
177 |
+
"name": "stdout",
|
178 |
+
"output_type": "stream",
|
179 |
+
"text": [
|
180 |
+
"Exported graph: graph(%input : Float(*, 3, 256, 256, strides=[196608, 65536, 256, 1], requires_grad=0, device=cpu),\n",
|
181 |
+
" %up4.up.weight : Float(128, 64, 2, 2, strides=[256, 4, 2, 1], requires_grad=1, device=cpu),\n",
|
182 |
+
" %up4.up.bias : Float(64, strides=[1], requires_grad=1, device=cpu),\n",
|
183 |
+
" %outc.conv.weight : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),\n",
|
184 |
+
" %outc.conv.bias : Float(1, strides=[1], requires_grad=1, device=cpu),\n",
|
185 |
+
" %onnx::Conv_226 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
186 |
+
" %onnx::Conv_227 : Float(64, strides=[1], requires_grad=0, device=cpu),\n",
|
187 |
+
" %onnx::Conv_229 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
188 |
+
" %onnx::Conv_230 : Float(64, strides=[1], requires_grad=0, device=cpu),\n",
|
189 |
+
" %onnx::Conv_232 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
190 |
+
" %onnx::Conv_233 : Float(128, strides=[1], requires_grad=0, device=cpu),\n",
|
191 |
+
" %onnx::Conv_235 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
192 |
+
" %onnx::Conv_236 : Float(128, strides=[1], requires_grad=0, device=cpu),\n",
|
193 |
+
" %onnx::Conv_238 : Float(64, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
194 |
+
" %onnx::Conv_239 : Float(64, strides=[1], requires_grad=0, device=cpu),\n",
|
195 |
+
" %onnx::Conv_241 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),\n",
|
196 |
+
" %onnx::Conv_242 : Float(64, strides=[1], requires_grad=0, device=cpu)):\n",
|
197 |
+
" %/inc/conv/conv.0/Conv_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/inc/conv/conv.0/Conv\"](%input, %onnx::Conv_226, %onnx::Conv_227), scope: __main__.UNet::/__main__.DoubleConv::inc/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.0 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
198 |
+
" %/inc/conv/conv.2/Relu_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/inc/conv/conv.2/Relu\"](%/inc/conv/conv.0/Conv_output_0), scope: __main__.UNet::/__main__.DoubleConv::inc/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.2 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
199 |
+
" %/inc/conv/conv.3/Conv_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/inc/conv/conv.3/Conv\"](%/inc/conv/conv.2/Relu_output_0, %onnx::Conv_229, %onnx::Conv_230), scope: __main__.UNet::/__main__.DoubleConv::inc/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.3 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
200 |
+
" %/inc/conv/conv.5/Relu_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/inc/conv/conv.5/Relu\"](%/inc/conv/conv.3/Conv_output_0), scope: __main__.UNet::/__main__.DoubleConv::inc/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.5 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
201 |
+
" %/down1/down/down.0/MaxPool_output_0 : Float(*, 64, 128, 128, strides=[1048576, 16384, 128, 1], requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2], onnx_name=\"/down1/down/down.0/MaxPool\"](%/inc/conv/conv.5/Relu_output_0), scope: __main__.UNet::/__main__.Down::down1/torch.nn.modules.container.Sequential::down/torch.nn.modules.pooling.MaxPool2d::down.0 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:780:0\n",
|
202 |
+
" %/down1/down/down.1/conv/conv.0/Conv_output_0 : Float(*, 128, 128, 128, strides=[2097152, 16384, 128, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/down1/down/down.1/conv/conv.0/Conv\"](%/down1/down/down.0/MaxPool_output_0, %onnx::Conv_232, %onnx::Conv_233), scope: __main__.UNet::/__main__.Down::down1/torch.nn.modules.container.Sequential::down/__main__.DoubleConv::down.1/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.0 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
203 |
+
" %/down1/down/down.1/conv/conv.2/Relu_output_0 : Float(*, 128, 128, 128, strides=[2097152, 16384, 128, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/down1/down/down.1/conv/conv.2/Relu\"](%/down1/down/down.1/conv/conv.0/Conv_output_0), scope: __main__.UNet::/__main__.Down::down1/torch.nn.modules.container.Sequential::down/__main__.DoubleConv::down.1/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.2 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
204 |
+
" %/down1/down/down.1/conv/conv.3/Conv_output_0 : Float(*, 128, 128, 128, strides=[2097152, 16384, 128, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/down1/down/down.1/conv/conv.3/Conv\"](%/down1/down/down.1/conv/conv.2/Relu_output_0, %onnx::Conv_235, %onnx::Conv_236), scope: __main__.UNet::/__main__.Down::down1/torch.nn.modules.container.Sequential::down/__main__.DoubleConv::down.1/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.3 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
205 |
+
" %/down1/down/down.1/conv/conv.5/Relu_output_0 : Float(*, 128, 128, 128, strides=[2097152, 16384, 128, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/down1/down/down.1/conv/conv.5/Relu\"](%/down1/down/down.1/conv/conv.3/Conv_output_0), scope: __main__.UNet::/__main__.Down::down1/torch.nn.modules.container.Sequential::down/__main__.DoubleConv::down.1/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.5 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
206 |
+
" %/up4/up/ConvTranspose_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=0, device=cpu) = onnx::ConvTranspose[dilations=[1, 1], group=1, kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2], onnx_name=\"/up4/up/ConvTranspose\"](%/down1/down/down.1/conv/conv.5/Relu_output_0, %up4.up.weight, %up4.up.bias), scope: __main__.UNet::/__main__.Up::up4/torch.nn.modules.conv.ConvTranspose2d::up # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:953:0\n",
|
207 |
+
" %/up4/Shape_output_0 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/up4/Shape\"](%/inc/conv/conv.5/Relu_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:37:0\n",
|
208 |
+
" %/up4/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={2}, onnx_name=\"/up4/Constant\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:37:0\n",
|
209 |
+
" %/up4/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/up4/Gather\"](%/up4/Shape_output_0, %/up4/Constant_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:37:0\n",
|
210 |
+
" %/up4/Shape_1_output_0 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/up4/Shape_1\"](%/up4/up/ConvTranspose_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
211 |
+
" %/up4/Constant_1_output_0 : Long(device=cpu) = onnx::Constant[value={2}, onnx_name=\"/up4/Constant_1\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
212 |
+
" %/up4/Gather_1_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/up4/Gather_1\"](%/up4/Shape_1_output_0, %/up4/Constant_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
213 |
+
" %/up4/Sub_output_0 : Long(requires_grad=0, device=cpu) = onnx::Sub[onnx_name=\"/up4/Sub\"](%/up4/Gather_output_0, %/up4/Gather_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
214 |
+
" %/up4/Shape_2_output_0 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/up4/Shape_2\"](%/inc/conv/conv.5/Relu_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
215 |
+
" %/up4/Constant_2_output_0 : Long(device=cpu) = onnx::Constant[value={3}, onnx_name=\"/up4/Constant_2\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
216 |
+
" %/up4/Gather_2_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/up4/Gather_2\"](%/up4/Shape_2_output_0, %/up4/Constant_2_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
217 |
+
" %/up4/Shape_3_output_0 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/up4/Shape_3\"](%/up4/up/ConvTranspose_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
218 |
+
" %/up4/Constant_3_output_0 : Long(device=cpu) = onnx::Constant[value={3}, onnx_name=\"/up4/Constant_3\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
219 |
+
" %/up4/Gather_3_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/up4/Gather_3\"](%/up4/Shape_3_output_0, %/up4/Constant_3_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
220 |
+
" %/up4/Sub_1_output_0 : Long(requires_grad=0, device=cpu) = onnx::Sub[onnx_name=\"/up4/Sub_1\"](%/up4/Gather_2_output_0, %/up4/Gather_3_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:38:0\n",
|
221 |
+
" %/up4/Constant_4_output_0 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/up4/Constant_4\"](), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
222 |
+
" %/up4/Div_output_0 : Long(device=cpu) = onnx::Div[onnx_name=\"/up4/Div\"](%/up4/Sub_1_output_0, %/up4/Constant_4_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
223 |
+
" %/up4/Cast_output_0 : Long(device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast\"](%/up4/Div_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
224 |
+
" %/up4/Cast_1_output_0 : Long(requires_grad=0, device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast_1\"](%/up4/Cast_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
225 |
+
" %/up4/Sub_2_output_0 : Long(requires_grad=0, device=cpu) = onnx::Sub[onnx_name=\"/up4/Sub_2\"](%/up4/Sub_1_output_0, %/up4/Cast_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
226 |
+
" %/up4/Constant_5_output_0 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/up4/Constant_5\"](), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
227 |
+
" %/up4/Div_1_output_0 : Long(device=cpu) = onnx::Div[onnx_name=\"/up4/Div_1\"](%/up4/Sub_output_0, %/up4/Constant_5_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
228 |
+
" %/up4/Cast_2_output_0 : Long(device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast_2\"](%/up4/Div_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
229 |
+
" %/up4/Cast_3_output_0 : Long(requires_grad=0, device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast_3\"](%/up4/Cast_2_output_0), scope: __main__.UNet::/__main__.Up::up4 # /home/gautham/.local/lib/python3.10/site-packages/torch/_tensor.py:867:0\n",
|
230 |
+
" %/up4/Sub_3_output_0 : Long(requires_grad=0, device=cpu) = onnx::Sub[onnx_name=\"/up4/Sub_3\"](%/up4/Sub_output_0, %/up4/Cast_3_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
231 |
+
" %onnx::Unsqueeze_175 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
232 |
+
" %/up4/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze\"](%/up4/Cast_1_output_0, %onnx::Unsqueeze_175), scope: __main__.UNet::/__main__.Up::up4\n",
|
233 |
+
" %onnx::Unsqueeze_177 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
234 |
+
" %/up4/Unsqueeze_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_1\"](%/up4/Sub_2_output_0, %onnx::Unsqueeze_177), scope: __main__.UNet::/__main__.Up::up4\n",
|
235 |
+
" %onnx::Unsqueeze_179 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
236 |
+
" %/up4/Unsqueeze_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_2\"](%/up4/Cast_3_output_0, %onnx::Unsqueeze_179), scope: __main__.UNet::/__main__.Up::up4\n",
|
237 |
+
" %onnx::Unsqueeze_181 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
238 |
+
" %/up4/Unsqueeze_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_3\"](%/up4/Sub_3_output_0, %onnx::Unsqueeze_181), scope: __main__.UNet::/__main__.Up::up4\n",
|
239 |
+
" %/up4/Concat_output_0 : Long(4, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/up4/Concat\"](%/up4/Unsqueeze_output_0, %/up4/Unsqueeze_1_output_0, %/up4/Unsqueeze_2_output_0, %/up4/Unsqueeze_3_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
240 |
+
" %onnx::Unsqueeze_184 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
241 |
+
" %/up4/Unsqueeze_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_4\"](%/up4/Cast_1_output_0, %onnx::Unsqueeze_184), scope: __main__.UNet::/__main__.Up::up4\n",
|
242 |
+
" %onnx::Unsqueeze_186 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
243 |
+
" %/up4/Unsqueeze_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_5\"](%/up4/Sub_2_output_0, %onnx::Unsqueeze_186), scope: __main__.UNet::/__main__.Up::up4\n",
|
244 |
+
" %onnx::Unsqueeze_188 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
245 |
+
" %/up4/Unsqueeze_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_6\"](%/up4/Cast_3_output_0, %onnx::Unsqueeze_188), scope: __main__.UNet::/__main__.Up::up4\n",
|
246 |
+
" %onnx::Unsqueeze_190 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
|
247 |
+
" %/up4/Unsqueeze_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/up4/Unsqueeze_7\"](%/up4/Sub_3_output_0, %onnx::Unsqueeze_190), scope: __main__.UNet::/__main__.Up::up4\n",
|
248 |
+
" %/up4/Concat_1_output_0 : Long(4, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/up4/Concat_1\"](%/up4/Unsqueeze_4_output_0, %/up4/Unsqueeze_5_output_0, %/up4/Unsqueeze_6_output_0, %/up4/Unsqueeze_7_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
249 |
+
" %onnx::Pad_193 : NoneType = prim::Constant(), scope: __main__.UNet::/__main__.Up::up4\n",
|
250 |
+
" %/up4/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name=\"/up4/Constant_6\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
251 |
+
" %/up4/Shape_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/up4/Shape_4\"](%/up4/Concat_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
252 |
+
" %/up4/Gather_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Gather[axis=0, onnx_name=\"/up4/Gather_4\"](%/up4/Shape_4_output_0, %/up4/Constant_6_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
253 |
+
" %/up4/Constant_7_output_0 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={8}, onnx_name=\"/up4/Constant_7\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
254 |
+
" %/up4/Sub_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Sub[onnx_name=\"/up4/Sub_4\"](%/up4/Constant_7_output_0, %/up4/Gather_4_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
255 |
+
" %/up4/Cast_4_output_0 : Long(4, strides=[1], device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast_4\"](%/up4/Concat_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
256 |
+
" %/up4/ConstantOfShape_output_0 : Long(4, strides=[1], device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name=\"/up4/ConstantOfShape\"](%/up4/Sub_4_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
257 |
+
" %/up4/Concat_2_output_0 : Long(8, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/up4/Concat_2\"](%/up4/Cast_4_output_0, %/up4/ConstantOfShape_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
258 |
+
" %/up4/Constant_8_output_0 : Long(2, strides=[1], device=cpu) = onnx::Constant[value=-1 2 [ CPULongType{2} ], onnx_name=\"/up4/Constant_8\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
259 |
+
" %/up4/Reshape_output_0 : Long(4, 2, strides=[2, 1], device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/up4/Reshape\"](%/up4/Concat_2_output_0, %/up4/Constant_8_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
260 |
+
" %/up4/Constant_9_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name=\"/up4/Constant_9\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
261 |
+
" %/up4/Constant_10_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name=\"/up4/Constant_10\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
262 |
+
" %/up4/Constant_11_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-9223372036854775807}, onnx_name=\"/up4/Constant_11\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
263 |
+
" %/up4/Constant_12_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name=\"/up4/Constant_12\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
264 |
+
" %/up4/Slice_output_0 : Long(4, 2, strides=[2, 1], device=cpu) = onnx::Slice[onnx_name=\"/up4/Slice\"](%/up4/Reshape_output_0, %/up4/Constant_10_output_0, %/up4/Constant_11_output_0, %/up4/Constant_9_output_0, %/up4/Constant_12_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
265 |
+
" %/up4/Transpose_output_0 : Long(2, 4, strides=[4, 1], device=cpu) = onnx::Transpose[perm=[1, 0], onnx_name=\"/up4/Transpose\"](%/up4/Slice_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
266 |
+
" %/up4/Constant_13_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name=\"/up4/Constant_13\"](), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
267 |
+
" %/up4/Reshape_1_output_0 : Long(8, strides=[1], device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/up4/Reshape_1\"](%/up4/Transpose_output_0, %/up4/Constant_13_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
268 |
+
" %/up4/Cast_5_output_0 : Long(8, strides=[1], device=cpu) = onnx::Cast[to=7, onnx_name=\"/up4/Cast_5\"](%/up4/Reshape_1_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
269 |
+
" %/up4/Pad_output_0 : Float(*, *, *, *, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Pad[mode=\"constant\", onnx_name=\"/up4/Pad\"](%/up4/up/ConvTranspose_output_0, %/up4/Cast_5_output_0, %onnx::Pad_193), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:39:0\n",
|
270 |
+
" %/up4/Concat_3_output_0 : Float(*, *, 256, 256, strides=[8388608, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Concat[axis=1, onnx_name=\"/up4/Concat_3\"](%/inc/conv/conv.5/Relu_output_0, %/up4/Pad_output_0), scope: __main__.UNet::/__main__.Up::up4 # /tmp/ipykernel_19565/1651252862.py:40:0\n",
|
271 |
+
" %/up4/conv/conv/conv.0/Conv_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/up4/conv/conv/conv.0/Conv\"](%/up4/Concat_3_output_0, %onnx::Conv_238, %onnx::Conv_239), scope: __main__.UNet::/__main__.Up::up4/__main__.DoubleConv::conv/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.0 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
272 |
+
" %/up4/conv/conv/conv.2/Relu_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/up4/conv/conv/conv.2/Relu\"](%/up4/conv/conv/conv.0/Conv_output_0), scope: __main__.UNet::/__main__.Up::up4/__main__.DoubleConv::conv/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.2 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
273 |
+
" %/up4/conv/conv/conv.3/Conv_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name=\"/up4/conv/conv/conv.3/Conv\"](%/up4/conv/conv/conv.2/Relu_output_0, %onnx::Conv_241, %onnx::Conv_242), scope: __main__.UNet::/__main__.Up::up4/__main__.DoubleConv::conv/torch.nn.modules.container.Sequential::conv/torch.nn.modules.conv.Conv2d::conv.3 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
274 |
+
" %/up4/conv/conv/conv.5/Relu_output_0 : Float(*, 64, 256, 256, strides=[4194304, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Relu[onnx_name=\"/up4/conv/conv/conv.5/Relu\"](%/up4/conv/conv/conv.3/Conv_output_0), scope: __main__.UNet::/__main__.Up::up4/__main__.DoubleConv::conv/torch.nn.modules.container.Sequential::conv/torch.nn.modules.activation.ReLU::conv.5 # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/functional.py:1453:0\n",
|
275 |
+
" %/outc/conv/Conv_output_0 : Float(*, 1, 256, 256, strides=[65536, 65536, 256, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name=\"/outc/conv/Conv\"](%/up4/conv/conv/conv.5/Relu_output_0, %outc.conv.weight, %outc.conv.bias), scope: __main__.UNet::/__main__.OutConv::outc/torch.nn.modules.conv.Conv2d::conv # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:458:0\n",
|
276 |
+
" %output : Float(*, 1, 256, 256, strides=[65536, 65536, 256, 1], requires_grad=1, device=cpu) = onnx::Sigmoid[onnx_name=\"/outc/sigmoid/Sigmoid\"](%/outc/conv/Conv_output_0), scope: __main__.UNet::/__main__.OutConv::outc/torch.nn.modules.activation.Sigmoid::sigmoid # /home/gautham/.local/lib/python3.10/site-packages/torch/nn/modules/activation.py:294:0\n",
|
277 |
+
" return (%output)\n",
|
278 |
+
"\n"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"name": "stderr",
|
283 |
+
"output_type": "stream",
|
284 |
+
"text": [
|
285 |
+
"/home/gautham/.local/lib/python3.10/site-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)\n",
|
286 |
+
" _C._jit_pass_onnx_graph_shape_type_inference(\n",
|
287 |
+
"/home/gautham/.local/lib/python3.10/site-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)\n",
|
288 |
+
" _C._jit_pass_onnx_graph_shape_type_inference(\n"
|
289 |
+
]
|
290 |
+
}
|
291 |
+
],
|
292 |
+
"source": [
|
293 |
+
"onnx_path = '../weights/model.onnx'\n",
|
294 |
+
"\n",
|
295 |
+
"torch.onnx.export(model,\n",
|
296 |
+
" dummy_input,\n",
|
297 |
+
" onnx_path,\n",
|
298 |
+
" verbose=True,\n",
|
299 |
+
" input_names = ['input'], # the model's input names\n",
|
300 |
+
" output_names = ['output'], # the model's output names\n",
|
301 |
+
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
302 |
+
" 'output' : {0 : 'batch_size'}})"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "markdown",
|
307 |
+
"metadata": {},
|
308 |
+
"source": [
|
309 |
+
"## Verifying the ONNX model"
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "code",
|
314 |
+
"execution_count": null,
|
315 |
+
"metadata": {},
|
316 |
+
"outputs": [],
|
317 |
+
"source": [
|
318 |
+
"import onnx\n",
|
319 |
+
"\n",
|
320 |
+
"onnx_model = onnx.load(onnx_path)\n",
|
321 |
+
"onnx.checker.check_model(onnx_model)"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "markdown",
|
326 |
+
"metadata": {},
|
327 |
+
"source": [
|
328 |
+
"## Comparing ONNX Runtime and PyTorch results"
|
329 |
+
]
|
330 |
+
},
|
331 |
+
{
|
332 |
+
"cell_type": "code",
|
333 |
+
"execution_count": null,
|
334 |
+
"metadata": {},
|
335 |
+
"outputs": [
|
336 |
+
{
|
337 |
+
"name": "stdout",
|
338 |
+
"output_type": "stream",
|
339 |
+
"text": [
|
340 |
+
"Exported model has been tested with ONNXRuntime, and the result looks good!\n"
|
341 |
+
]
|
342 |
+
}
|
343 |
+
],
|
344 |
+
"source": [
|
345 |
+
"import onnxruntime\n",
|
346 |
+
"import numpy as np\n",
|
347 |
+
"\n",
|
348 |
+
"ort_session = onnxruntime.InferenceSession(onnx_path)\n",
|
349 |
+
"\n",
|
350 |
+
"def to_numpy(tensor):\n",
|
351 |
+
" return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
|
352 |
+
"\n",
|
353 |
+
"ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}\n",
|
354 |
+
"ort_outs = ort_session.run(None, ort_inputs)\n",
|
355 |
+
"\n",
|
356 |
+
"np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)\n",
|
357 |
+
"\n",
|
358 |
+
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")\n"
|
359 |
+
]
|
360 |
+
}
|
361 |
+
],
|
362 |
+
"metadata": {
|
363 |
+
"kernelspec": {
|
364 |
+
"display_name": "Python 3.10.6 64-bit",
|
365 |
+
"language": "python",
|
366 |
+
"name": "python3"
|
367 |
+
},
|
368 |
+
"language_info": {
|
369 |
+
"codemirror_mode": {
|
370 |
+
"name": "ipython",
|
371 |
+
"version": 3
|
372 |
+
},
|
373 |
+
"file_extension": ".py",
|
374 |
+
"mimetype": "text/x-python",
|
375 |
+
"name": "python",
|
376 |
+
"nbconvert_exporter": "python",
|
377 |
+
"pygments_lexer": "ipython3",
|
378 |
+
"version": "3.10.6"
|
379 |
+
},
|
380 |
+
"orig_nbformat": 4,
|
381 |
+
"vscode": {
|
382 |
+
"interpreter": {
|
383 |
+
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
|
384 |
+
}
|
385 |
+
}
|
386 |
+
},
|
387 |
+
"nbformat": 4,
|
388 |
+
"nbformat_minor": 2
|
389 |
+
}
|