vvd2003 commited on
Commit
6fa8c33
1 Parent(s): 2cd132f

Upload 15 files

Browse files
__pycache__/class_names.cpython-39.pyc ADDED
Binary file (8.59 kB). View file
 
__pycache__/model.cpython-39.pyc ADDED
Binary file (5.55 kB). View file
 
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from class_names import class_names
7
+ from model import Load_model
8
+ from timeit import default_timer as timer
9
+ from typing import Tuple, Dict
10
+
11
+
12
+ ### 1. Model and transforms preparation ###
13
+
14
+ # Create model and transform
15
+ model, transforms = Load_model()
16
+
17
+ # Load saved weights
18
+ def load_checkpoint(checkpoint_file, model, device='cpu'):
19
+ print("=> Loading checkpoint")
20
+ checkpoint = torch.load(checkpoint_file, map_location=device)
21
+ model.load_state_dict(checkpoint["state_dict"])
22
+ load_checkpoint('model_checkpoint.pt', model)
23
+
24
+ ### 2. Predict function ###
25
+
26
+ # Create predict function
27
+ def predict(img) -> Tuple[Dict, float]:
28
+ """Transforms and performs a prediction on img and returns prediction and time taken.
29
+ """
30
+ # Start the timer
31
+ start_time = timer()
32
+
33
+ # Transform the target image and add a batch dimension
34
+ img = transforms(img).unsqueeze(0)
35
+
36
+ # Put model into evaluation mode and turn on inference mode
37
+ model.eval()
38
+ with torch.inference_mode():
39
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
40
+ pred_probs = torch.softmax(model(img), dim=1)
41
+
42
+ # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
43
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
44
+
45
+ # Calculate the prediction time
46
+ pred_time = round(timer() - start_time, 5)
47
+
48
+ # Return the prediction dictionary and prediction time
49
+ return pred_labels_and_probs, pred_time
50
+
51
+
52
+ ### 3. Gradio app ###
53
+
54
+ # Create title, description and article strings
55
+ title = "BirdVision 500 🦅🦆🐦🕊🦤🦢🦜"
56
+ description = "A model based on YoLov8 classification 500 birds."
57
+ article = "Created on [GITHUB](https://github.com/vvduc1803?tab=repositories/)."
58
+
59
+ # Create examples list from "examples/" directory
60
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
61
+
62
+ # Create the Gradio demo
63
+ demo = gr.Interface(fn=predict, # mapping function from input to output
64
+ inputs=gr.Image(type="pil"), # what are the inputs?
65
+ outputs=[gr.Label(num_top_classes=10, label="Predictions"), # what are the outputs?
66
+ gr.Number(label="Prediction time (s)")],
67
+ # our fn has two outputs, therefore we have two outputs
68
+ # Create examples list from "examples/" directory
69
+ examples=example_list,
70
+ title=title,
71
+ description=description,
72
+ article=article)
73
+
74
+ # Launch the demo!
75
+ demo.launch()
class_names.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class_names = ['ABBOTTS BABBLER', 'ABBOTTS BOOBY', 'ABYSSINIAN GROUND HORNBILL', 'AFRICAN CROWNED CRANE', 'AFRICAN '
2
+ 'EMERALD '
3
+ 'CUCKOO',
4
+ 'AFRICAN FIREFINCH', 'AFRICAN OYSTER CATCHER', 'AFRICAN PIED HORNBILL', 'AFRICAN PYGMY GOOSE',
5
+ 'ALBATROSS', 'ALBERTS TOWHEE', 'ALEXANDRINE PARAKEET', 'ALPINE CHOUGH', 'ALTAMIRA YELLOWTHROAT',
6
+ 'AMERICAN AVOCET', 'AMERICAN BITTERN', 'AMERICAN COOT', 'AMERICAN FLAMINGO', 'AMERICAN GOLDFINCH',
7
+ 'AMERICAN KESTREL', 'AMERICAN PIPIT', 'AMERICAN REDSTART', 'AMERICAN ROBIN', 'AMERICAN WIGEON',
8
+ 'AMETHYST WOODSTAR', 'ANDEAN GOOSE', 'ANDEAN LAPWING', 'ANDEAN SISKIN', 'ANHINGA', 'ANIANIAU',
9
+ 'ANNAS HUMMINGBIRD', 'ANTBIRD', 'ANTILLEAN EUPHONIA', 'APAPANE', 'APOSTLEBIRD', 'ARARIPE MANAKIN',
10
+ 'ASHY STORM PETREL', 'ASHY THRUSHBIRD', 'ASIAN CRESTED IBIS', 'ASIAN DOLLARD BIRD', 'AUCKLAND SHAQ',
11
+ 'AUSTRAL CANASTERO', 'AUSTRALASIAN FIGBIRD', 'AVADAVAT', 'AZARAS SPINETAIL', 'AZURE BREASTED PITTA',
12
+ 'AZURE JAY', 'AZURE TANAGER', 'AZURE TIT', 'BAIKAL TEAL', 'BALD EAGLE', 'BALD IBIS', 'BALI STARLING',
13
+ 'BALTIMORE ORIOLE', 'BANANAQUIT', 'BAND TAILED GUAN', 'BANDED BROADBILL', 'BANDED PITA',
14
+ 'BANDED STILT', 'BAR-TAILED GODWIT', 'BARN OWL', 'BARN SWALLOW', 'BARRED PUFFBIRD',
15
+ 'BARROWS GOLDENEYE', 'BAY-BREASTED WARBLER', 'BEARDED BARBET', 'BEARDED BELLBIRD', 'BEARDED REEDLING',
16
+ 'BELTED KINGFISHER', 'BIRD OF PARADISE', 'BLACK AND YELLOW BROADBILL', 'BLACK BAZA', 'BLACK COCKATO',
17
+ 'BLACK FACED SPOONBILL', 'BLACK FRANCOLIN', 'BLACK HEADED CAIQUE', 'BLACK NECKED STILT',
18
+ 'BLACK SKIMMER', 'BLACK SWAN', 'BLACK TAIL CRAKE', 'BLACK THROATED BUSHTIT', 'BLACK THROATED HUET',
19
+ 'BLACK THROATED WARBLER', 'BLACK VENTED SHEARWATER', 'BLACK VULTURE', 'BLACK-CAPPED CHICKADEE',
20
+ 'BLACK-NECKED GREBE', 'BLACK-THROATED SPARROW', 'BLACKBURNIAM WARBLER', 'BLONDE CRESTED WOODPECKER',
21
+ 'BLOOD PHEASANT', 'BLUE COAU', 'BLUE DACNIS', 'BLUE GRAY GNATCATCHER', 'BLUE GROSBEAK', 'BLUE GROUSE',
22
+ 'BLUE HERON', 'BLUE MALKOHA', 'BLUE THROATED TOUCANET', 'BOBOLINK', 'BORNEAN BRISTLEHEAD',
23
+ 'BORNEAN LEAFBIRD', 'BORNEAN PHEASANT', 'BRANDT CORMARANT', 'BREWERS BLACKBIRD', 'BROWN CREPPER',
24
+ 'BROWN HEADED COWBIRD', 'BROWN NOODY', 'BROWN THRASHER', 'BUFFLEHEAD', 'BULWERS PHEASANT', 'BURCHELLS '
25
+ 'COURSER',
26
+ 'BUSH TURKEY', 'CAATINGA CACHOLOTE', 'CACTUS WREN', 'CALIFORNIA CONDOR', 'CALIFORNIA GULL',
27
+ 'CALIFORNIA QUAIL', 'CAMPO FLICKER', 'CANARY', 'CANVASBACK', 'CAPE GLOSSY STARLING', 'CAPE LONGCLAW',
28
+ 'CAPE MAY WARBLER', 'CAPE ROCK THRUSH', 'CAPPED HERON', 'CAPUCHINBIRD', 'CARMINE BEE-EATER',
29
+ 'CASPIAN TERN', 'CASSOWARY', 'CEDAR WAXWING', 'CERULEAN WARBLER', 'CHARA DE COLLAR', 'CHATTERING '
30
+ 'LORY',
31
+ 'CHESTNET BELLIED EUPHONIA', 'CHINESE BAMBOO PARTRIDGE', 'CHINESE POND HERON', 'CHIPPING SPARROW',
32
+ 'CHUCAO TAPACULO', 'CHUKAR PARTRIDGE', 'CINNAMON ATTILA', 'CINNAMON FLYCATCHER', 'CINNAMON TEAL',
33
+ 'CLARKS GREBE', 'CLARKS NUTCRACKER', 'COCK OF THE ROCK', 'COCKATOO', 'COLLARED ARACARI',
34
+ 'COLLARED CRESCENTCHEST', 'COMMON FIRECREST', 'COMMON GRACKLE', 'COMMON HOUSE MARTIN', 'COMMON IORA',
35
+ 'COMMON LOON', 'COMMON POORWILL', 'COMMON STARLING', 'COPPERY TAILED COUCAL', 'CRAB PLOVER',
36
+ 'CRANE HAWK', 'CREAM COLORED WOODPECKER', 'CRESTED AUKLET', 'CRESTED CARACARA', 'CRESTED COUA',
37
+ 'CRESTED FIREBACK', 'CRESTED KINGFISHER', 'CRESTED NUTHATCH', 'CRESTED OROPENDOLA', 'CRESTED SERPENT '
38
+ 'EAGLE',
39
+ 'CRESTED SHRIKETIT', 'CRESTED WOOD PARTRIDGE', 'CRIMSON CHAT', 'CRIMSON SUNBIRD', 'CROW',
40
+ 'CROWNED PIGEON', 'CUBAN TODY', 'CUBAN TROGON', 'CURL CRESTED ARACURI', 'D-ARNAUDS BARBET',
41
+ 'DALMATIAN PELICAN', 'DARJEELING WOODPECKER', 'DARK EYED JUNCO', 'DAURIAN REDSTART', 'DEMOISELLE '
42
+ 'CRANE',
43
+ 'DOUBLE BARRED FINCH', 'DOUBLE BRESTED CORMARANT', 'DOUBLE EYED FIG PARROT', 'DOWNY WOODPECKER',
44
+ 'DUSKY LORY', 'DUSKY ROBIN', 'EARED PITA', 'EASTERN BLUEBIRD', 'EASTERN BLUEBONNET', 'EASTERN GOLDEN '
45
+ 'WEAVER',
46
+ 'EASTERN MEADOWLARK', 'EASTERN ROSELLA', 'EASTERN TOWEE', 'EASTERN WIP POOR WILL', 'EASTERN YELLOW '
47
+ 'ROBIN',
48
+ 'ECUADORIAN HILLSTAR', 'EGYPTIAN GOOSE', 'ELEGANT TROGON', 'ELLIOTS PHEASANT', 'EMERALD TANAGER',
49
+ 'EMPEROR PENGUIN', 'EMU', 'ENGGANO MYNA', 'EURASIAN BULLFINCH', 'EURASIAN GOLDEN ORIOLE',
50
+ 'EURASIAN MAGPIE', 'EUROPEAN GOLDFINCH', 'EUROPEAN TURTLE DOVE', 'EVENING GROSBEAK', 'FAIRY BLUEBIRD',
51
+ 'FAIRY PENGUIN', 'FAIRY TERN', 'FAN TAILED WIDOW', 'FASCIATED WREN', 'FIERY MINIVET', 'FIORDLAND '
52
+ 'PENGUIN',
53
+ 'FIRE TAILLED MYZORNIS', 'FLAME BOWERBIRD', 'FLAME TANAGER', 'FRIGATE', 'FRILL BACK PIGEON',
54
+ 'GAMBELS QUAIL', 'GANG GANG COCKATOO', 'GILA WOODPECKER', 'GILDED FLICKER', 'GLOSSY IBIS',
55
+ 'GO AWAY BIRD', 'GOLD WING WARBLER', 'GOLDEN BOWER BIRD', 'GOLDEN CHEEKED WARBLER',
56
+ 'GOLDEN CHLOROPHONIA', 'GOLDEN EAGLE', 'GOLDEN PARAKEET', 'GOLDEN PHEASANT', 'GOLDEN PIPIT',
57
+ 'GOULDIAN FINCH', 'GRANDALA', 'GRAY CATBIRD', 'GRAY KINGBIRD', 'GRAY PARTRIDGE', 'GREAT ARGUS',
58
+ 'GREAT GRAY OWL', 'GREAT JACAMAR', 'GREAT KISKADEE', 'GREAT POTOO', 'GREAT TINAMOU', 'GREAT XENOPS',
59
+ 'GREATER PEWEE', 'GREATER PRAIRIE CHICKEN', 'GREATOR SAGE GROUSE', 'GREEN BROADBILL', 'GREEN JAY',
60
+ 'GREEN MAGPIE', 'GREEN WINGED DOVE', 'GREY CUCKOOSHRIKE', 'GREY HEADED FISH EAGLE', 'GREY PLOVER',
61
+ 'GROVED BILLED ANI', 'GUINEA TURACO', 'GUINEAFOWL', 'GURNEYS PITTA', 'GYRFALCON', 'HAMERKOP',
62
+ 'HARLEQUIN DUCK', 'HARLEQUIN QUAIL', 'HARPY EAGLE', 'HAWAIIAN GOOSE', 'HAWFINCH', 'HELMET VANGA',
63
+ 'HEPATIC TANAGER', 'HIMALAYAN BLUETAIL', 'HIMALAYAN MONAL', 'HOATZIN', 'HOODED MERGANSER', 'HOOPOES',
64
+ 'HORNED GUAN', 'HORNED LARK', 'HORNED SUNGEM', 'HOUSE FINCH', 'HOUSE SPARROW', 'HYACINTH MACAW',
65
+ 'IBERIAN MAGPIE', 'IBISBILL', 'IMPERIAL SHAQ', 'INCA TERN', 'INDIAN BUSTARD', 'INDIAN PITTA',
66
+ 'INDIAN ROLLER', 'INDIAN VULTURE', 'INDIGO BUNTING', 'INDIGO FLYCATCHER', 'INLAND DOTTEREL',
67
+ 'IVORY BILLED ARACARI', 'IVORY GULL', 'IWI', 'JABIRU', 'JACK SNIPE', 'JACOBIN PIGEON',
68
+ 'JANDAYA PARAKEET', 'JAPANESE ROBIN', 'JAVA SPARROW', 'JOCOTOCO ANTPITTA', 'KAGU', 'KAKAPO',
69
+ 'KILLDEAR', 'KING EIDER', 'KING VULTURE', 'KIWI', 'KOOKABURRA', 'LARK BUNTING', 'LAUGHING GULL',
70
+ 'LAZULI BUNTING', 'LESSER ADJUTANT', 'LILAC ROLLER', 'LIMPKIN', 'LITTLE AUK', 'LOGGERHEAD SHRIKE',
71
+ 'LONG-EARED OWL', 'LOONEY BIRDS', 'LUCIFER HUMMINGBIRD', 'MAGPIE GOOSE', 'MALABAR HORNBILL',
72
+ 'MALACHITE KINGFISHER', 'MALAGASY WHITE EYE', 'MALEO', 'MALLARD DUCK', 'MANDRIN DUCK',
73
+ 'MANGROVE CUCKOO', 'MARABOU STORK', 'MASKED BOBWHITE', 'MASKED BOOBY', 'MASKED LAPWING',
74
+ 'MCKAYS BUNTING', 'MERLIN', 'MIKADO PHEASANT', 'MILITARY MACAW', 'MOURNING DOVE', 'MYNA',
75
+ 'NICOBAR PIGEON', 'NOISY FRIARBIRD', 'NORTHERN BEARDLESS TYRANNULET', 'NORTHERN CARDINAL',
76
+ 'NORTHERN FLICKER', 'NORTHERN FULMAR', 'NORTHERN GANNET', 'NORTHERN GOSHAWK', 'NORTHERN JACANA',
77
+ 'NORTHERN MOCKINGBIRD', 'NORTHERN PARULA', 'NORTHERN RED BISHOP', 'NORTHERN SHOVELER', 'OCELLATED '
78
+ 'TURKEY',
79
+ 'OKINAWA RAIL', 'ORANGE BRESTED BUNTING', 'ORIENTAL BAY OWL', 'ORNATE HAWK EAGLE', 'OSPREY',
80
+ 'OSTRICH', 'OVENBIRD', 'OYSTER CATCHER', 'PAINTED BUNTING', 'PALILA', 'PALM NUT VULTURE',
81
+ 'PARADISE TANAGER', 'PARAKETT AKULET', 'PARUS MAJOR', 'PATAGONIAN SIERRA FINCH', 'PEACOCK',
82
+ 'PEREGRINE FALCON', 'PHAINOPEPLA', 'PHILIPPINE EAGLE', 'PINK ROBIN', 'PLUSH CRESTED JAY',
83
+ 'POMARINE JAEGER', 'PUFFIN', 'PUNA TEAL', 'PURPLE FINCH', 'PURPLE GALLINULE', 'PURPLE MARTIN',
84
+ 'PURPLE SWAMPHEN', 'PYGMY KINGFISHER', 'PYRRHULOXIA', 'QUETZAL', 'RAINBOW LORIKEET', 'RAZORBILL',
85
+ 'RED BEARDED BEE EATER', 'RED BELLIED PITTA', 'RED BILLED TROPICBIRD', 'RED BROWED FINCH', 'RED FACED '
86
+ 'CORMORANT',
87
+ 'RED FACED WARBLER', 'RED FODY', 'RED HEADED DUCK', 'RED HEADED WOODPECKER', 'RED KNOT', 'RED LEGGED '
88
+ 'HONEYCREEPER'
89
+ '',
90
+ 'RED NAPED TROGON', 'RED SHOULDERED HAWK', 'RED TAILED HAWK', 'RED TAILED THRUSH', 'RED WINGED '
91
+ 'BLACKBIRD',
92
+ 'RED WISKERED BULBUL', 'REGENT BOWERBIRD', 'RING-NECKED PHEASANT', 'ROADRUNNER', 'ROCK DOVE',
93
+ 'ROSE BREASTED COCKATOO', 'ROSE BREASTED GROSBEAK', 'ROSEATE SPOONBILL', 'ROSY FACED LOVEBIRD',
94
+ 'ROUGH LEG BUZZARD', 'ROYAL FLYCATCHER', 'RUBY CROWNED KINGLET', 'RUBY THROATED HUMMINGBIRD',
95
+ 'RUDY KINGFISHER', 'RUFOUS KINGFISHER', 'RUFUOS MOTMOT', 'SAMATRAN THRUSH', 'SAND MARTIN',
96
+ 'SANDHILL CRANE', 'SATYR TRAGOPAN', 'SAYS PHOEBE', 'SCARLET CROWNED FRUIT DOVE', 'SCARLET FACED '
97
+ 'LIOCICHLA',
98
+ 'SCARLET IBIS', 'SCARLET MACAW', 'SCARLET TANAGER', 'SHOEBILL', 'SHORT BILLED DOWITCHER',
99
+ 'SMITHS LONGSPUR', 'SNOW GOOSE', 'SNOWY EGRET', 'SNOWY OWL', 'SNOWY PLOVER', 'SORA', 'SPANGLED '
100
+ 'COTINGA',
101
+ 'SPLENDID WREN', 'SPOON BILED SANDPIPER', 'SPOTTED CATBIRD', 'SPOTTED WHISTLING DUCK', 'SRI LANKA BLUE '
102
+ 'MAGPIE',
103
+ 'STEAMER DUCK', 'STORK BILLED KINGFISHER', 'STRIATED CARACARA', 'STRIPED OWL', 'STRIPPED MANAKIN',
104
+ 'STRIPPED SWALLOW', 'SUNBITTERN', 'SUPERB STARLING', 'SURF SCOTER', 'SWINHOES PHEASANT', 'TAILORBIRD',
105
+ 'TAIWAN MAGPIE', 'TAKAHE', 'TASMANIAN HEN', 'TAWNY FROGMOUTH', 'TEAL DUCK', 'TIT MOUSE', 'TOUCHAN',
106
+ 'TOWNSENDS WARBLER', 'TREE SWALLOW', 'TRICOLORED BLACKBIRD', 'TROPICAL KINGBIRD', 'TRUMPTER SWAN',
107
+ 'TURKEY VULTURE', 'TURQUOISE MOTMOT', 'UMBRELLA BIRD', 'VARIED THRUSH', 'VEERY', 'VENEZUELIAN '
108
+ 'TROUPIAL', 'VERDIN',
109
+ 'VERMILION FLYCATHER', 'VICTORIA CROWNED PIGEON', 'VIOLET BACKED STARLING', 'VIOLET GREEN SWALLOW',
110
+ 'VIOLET TURACO', 'VULTURINE GUINEAFOWL', 'WALL CREAPER', 'WATTLED CURASSOW', 'WATTLED LAPWING',
111
+ 'WHIMBREL', 'WHITE BROWED CRAKE', 'WHITE CHEEKED TURACO', 'WHITE CRESTED HORNBILL', 'WHITE EARED '
112
+ 'HUMMINGBIRD',
113
+ 'WHITE NECKED RAVEN', 'WHITE TAILED TROPIC', 'WHITE THROATED BEE EATER', 'WILD TURKEY',
114
+ 'WILLOW PTARMIGAN', 'WILSONS BIRD OF PARADISE', 'WOOD DUCK', 'WOOD THRUSH', 'WRENTIT', 'YELLOW BELLIED '
115
+ 'FLOWERPECKER',
116
+ 'YELLOW CACIQUE', 'YELLOW HEADED BLACKBIRD']
117
+
examples/002.jpg ADDED
examples/026.jpg ADDED
examples/anianiau.jpg ADDED
examples/azure jay.jpg ADDED
examples/banded stilt.jpg ADDED
examples/northern ganner.jpg ADDED
examples/sand martin.jpg ADDED
examples/scalert macaw.jpg ADDED
examples/wall creaper.jpg ADDED
model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+
5
+ class CNNBlock(nn.Module):
6
+ """Base block in CNN"""
7
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bn_act=True):
8
+ super().__init__()
9
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not bn_act)
10
+ self.bn = nn.BatchNorm2d(out_channels)
11
+ self.silu = nn.SiLU()
12
+ self.use_bn_act = bn_act
13
+
14
+ def forward(self, x):
15
+ if self.use_bn_act:
16
+ x = self.silu(self.bn(self.conv(x)))
17
+
18
+ return x
19
+ else:
20
+ return self.conv(x)
21
+
22
+ class BottleNeckBlock(nn.Module):
23
+ def __init__(self, channels, short_cut=True):
24
+ super().__init__()
25
+ self.short_cut = short_cut
26
+ self.Conv = nn.Sequential(CNNBlock(channels, channels//2, 3, 1, 1),
27
+ CNNBlock(channels//2, channels, 3, 1, 1))
28
+
29
+ def forward(self, x):
30
+ if self.short_cut:
31
+ return self.Conv(x) + x
32
+ else:
33
+ return self.Conv(x)
34
+
35
+ class C2FBlock(nn.Module):
36
+ def __init__(self, in_channels, out_channels, **kwargs):
37
+ super().__init__()
38
+ self.in_channels = in_channels
39
+ self.out_channels = out_channels
40
+ self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
41
+ self.Conv_end = CNNBlock(int(0.5*(1+2)*out_channels), out_channels, kernel_size=1, stride=1, padding=0)
42
+ self.BottleNeck = BottleNeckBlock(out_channels//2, **kwargs)
43
+
44
+ def forward(self, x):
45
+ x = self.Conv(x)
46
+ x, x1 = torch.split(x, self.out_channels//2, dim=1)
47
+ x2 = self.BottleNeck(x1)
48
+ x = torch.cat([x, x1, x2], dim=1)
49
+ x = self.Conv_end(x)
50
+ return x
51
+
52
+ class C2F_2_Block(nn.Module):
53
+ def __init__(self, in_channels, out_channels, **kwargs):
54
+ super().__init__()
55
+ self.in_channels = in_channels
56
+ self.out_channels = out_channels
57
+ self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
58
+ self.Conv_end = CNNBlock(int(0.5*(2+2)*out_channels), out_channels, kernel_size=1, stride=1, padding=0)
59
+ self.BottleNeck = BottleNeckBlock(out_channels//2, **kwargs)
60
+
61
+ def forward(self, x):
62
+ x = self.Conv(x)
63
+ x, x1 = torch.split(x, self.out_channels//2, dim=1)
64
+ x2 = self.BottleNeck(x1)
65
+ x3 = self.BottleNeck(x2)
66
+ x = torch.cat([x, x1, x2, x3], dim=1)
67
+ x = self.Conv_end(x)
68
+ return x
69
+
70
+ class SPPFBlock(nn.Module):
71
+ def __init__(self, channels):
72
+ super().__init__()
73
+ self.Conv = CNNBlock(channels, channels, kernel_size=1, stride=1, padding=0)
74
+ self.Conv_end = CNNBlock(4*channels, channels, kernel_size=1, stride=1, padding=0)
75
+ self.MaxPool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
76
+
77
+ def forward(self, x):
78
+ x = self.Conv(x)
79
+ x = torch.cat([x, self.MaxPool(x), self.MaxPool(self.MaxPool(x)), self.MaxPool(self.MaxPool(self.MaxPool(x)))],
80
+ dim=1)
81
+ x = self.Conv_end(x)
82
+ return x
83
+
84
+ class Classifier(nn.Module):
85
+ def __init__(self, num_classes=500):
86
+ super().__init__()
87
+ self.Conv = nn.Sequential(CNNBlock(512, 1280, kernel_size=1, stride=1, padding=0))
88
+ self.Flatten = nn.Flatten()
89
+ self.Linear = nn.Sequential(nn.Linear(62720, num_classes))
90
+
91
+ def forward(self, x):
92
+ x = self.Conv(x)
93
+ x = self.Flatten(x)
94
+ x = self.Linear(x)
95
+ return x
96
+
97
+ class Yolov8_cls(nn.Module):
98
+ """Model architecture based page: https://blog.roboflow.com/whats-new-in-yolov8/
99
+ and the ONNX file of yolov8_cls.onnx"""
100
+
101
+ def __init__(self, in_channels, num_classes=500):
102
+ super().__init__()
103
+ self.Block1 = nn.Sequential(CNNBlock(in_channels, 32, 3, 2, 1),
104
+ CNNBlock(32, 64, 3, 2, 1))
105
+
106
+ self.Block2 = C2FBlock(64, 64)
107
+
108
+ self.Block3 = nn.Sequential(CNNBlock(64, 128, 3, 2, 1),
109
+ C2F_2_Block(128, 128))
110
+
111
+ self.Block4 = nn.Sequential(CNNBlock(128, 256, 3, 2, 1),
112
+ C2F_2_Block(256, 256))
113
+
114
+ self.Block5 = nn.Sequential(CNNBlock(256, 512, 3, 2, 1),
115
+ C2F_2_Block(512, 512))
116
+
117
+ self.Block6 = Classifier(num_classes)
118
+
119
+ def forward(self, x):
120
+ x = self.Block1(x)
121
+ x = self.Block2(x)
122
+ x = self.Block3(x)
123
+ x = self.Block4(x)
124
+ x = self.Block5(x)
125
+ x = self.Block6(x)
126
+ return x
127
+
128
+
129
+ def Load_model():
130
+ """Load model and transforms.
131
+ Returns:
132
+ model (torch.nn.Module): EffNetB2 feature extractor model.
133
+ transforms (torchvision.transforms): EffNetB2 image transforms.
134
+ """
135
+ IMAGE_SIZE= 224
136
+
137
+ model = Yolov8_cls(3)
138
+
139
+ transform = transforms.Compose([transforms.Resize(IMAGE_SIZE),
140
+ transforms.CenterCrop(IMAGE_SIZE),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
143
+
144
+ return model, transform
model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cc784f64ed37e9041760b8877c1e2bbb9ec9e46fdf05db1ed4441e66e6331ed
3
+ size 425175633