etemkocaaslan committed on
Commit
c957273
1 Parent(s): 2ee95b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -100
app.py CHANGED
@@ -7,6 +7,7 @@ import requests
7
  import random
8
  import gradio as gr
9
 
 
10
  image_prediction_models = {
11
  'resnet': models.resnet50,
12
  'alexnet': models.alexnet,
@@ -26,103 +27,161 @@ image_prediction_models = {
26
  'convnext': models.convnext_tiny
27
  }
28
 
29
- def load_pretrained_model(model_name):
30
- model_name_lower = model_name.lower()
31
- if model_name_lower in image_prediction_models:
32
- model_class = image_prediction_models[model_name_lower]
33
- model = model_class(pretrained=True)
34
- return model
35
- else:
36
- raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
37
-
38
- def get_model_names(models_dict):
39
- return [name.capitalize() for name in models_dict.keys()]
40
-
41
- model_list = get_model_names(image_prediction_models)
42
-
43
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
44
-
45
- def preprocess(model_name):
46
- input_size = 224
47
- if model_name == 'inception':
48
- input_size = 299
49
- return transforms.Compose([
50
- transforms.Resize(256),
51
- transforms.CenterCrop(input_size),
52
- transforms.ToTensor(),
53
- normalize,
54
- ])
55
-
56
- response = requests.get("https://git.io/JJkYN")
57
- labels = response.text.split("\n")
58
-
59
- def postprocess_default(output):
60
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
61
- top_prob, top_catid = torch.topk(probabilities, 5)
62
- confidences = {labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
63
- return confidences
64
-
65
- def postprocess_inception(output):
66
- probabilities = torch.nn.functional.softmax(output[1], dim=0)
67
- top_prob, top_catid = torch.topk(probabilities, 5)
68
- confidences = {labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
69
- return confidences
70
-
71
- def classify_image(input_image, selected_model):
72
- preprocess_input = preprocess(model_name=selected_model)
73
- input_tensor = preprocess_input(input_image)
74
- input_batch = input_tensor.unsqueeze(0)
75
- model = load_pretrained_model(selected_model)
76
-
77
- if torch.cuda.is_available():
78
- input_batch = input_batch.to('cuda')
79
- model.to('cuda')
80
-
81
- model.eval()
82
- with torch.no_grad():
83
- output = model(input_batch)
84
-
85
- if selected_model.lower() == 'inception':
86
- return postprocess_inception(output)
87
- else:
88
- return postprocess_default(output)
89
-
90
- def get_random_image():
91
- cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
92
- random_idx = random.randint(0, len(cifar10) - 1)
93
- image, _ = cifar10[random_idx]
94
- image = transforms.ToPILImage()(image)
95
- return image
96
-
97
- def generate_random_image():
98
- image = get_random_image()
99
- return image
100
-
101
- def classify_generated_image(image, model):
102
- return classify_image(image, model)
103
-
104
- with gr.Blocks() as demo:
105
- with gr.Tabs():
106
- with gr.TabItem("Upload Image"):
107
- with gr.Row():
108
- with gr.Column():
109
- upload_image = gr.Image(type='pil', label="Upload Image")
110
- model_dropdown_upload = gr.Dropdown(model_list, label="Select Model")
111
- classify_button_upload = gr.Button("Classify")
112
- with gr.Column():
113
- output_label_upload = gr.Label(num_top_classes=5)
114
- classify_button_upload.click(classify_image, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
115
-
116
- with gr.TabItem("Generate Random Image"):
117
- with gr.Row():
118
- with gr.Column():
119
- generate_button = gr.Button("Generate Random Image")
120
- random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
121
- with gr.Column():
122
- model_dropdown_random = gr.Dropdown(model_list, label="Select Model")
123
- classify_button_random = gr.Button("Classify")
124
- output_label_random = gr.Label(num_top_classes=5)
125
- generate_button.click(generate_random_image, inputs=[], outputs=random_image_output)
126
- classify_button_random.click(classify_generated_image, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)
127
-
128
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import random
8
  import gradio as gr
9
 
10
+ # Predefined models available in torchvision
11
  image_prediction_models = {
12
  'resnet': models.resnet50,
13
  'alexnet': models.alexnet,
 
27
  'convnext': models.convnext_tiny
28
  }
29
 
30
+ # Load a pretrained model from torchvision
31
class ModelLoader:
    """Build pretrained models from a name -> constructor mapping.

    Models are cached per name after the first request, so classifying
    several images with the same architecture does not rebuild the network
    and reload its weights each time. The trade-off is that every model
    requested so far stays in memory for the lifetime of the loader.
    """

    def __init__(self, model_dict):
        """Store the available model constructors.

        Args:
            model_dict: Mapping from lowercase model name to a callable that
                accepts ``pretrained=True`` and returns a model instance.
        """
        self.model_dict = model_dict
        # Lowercase model name -> model instance that was already built.
        self._loaded_models = {}

    def load_model(self, model_name):
        """Return the pretrained model registered under ``model_name``.

        The lookup is case-insensitive: the name is lowercased before it is
        matched against the keys of ``model_dict``.

        Args:
            model_name: Name of the model, in any letter case.

        Returns:
            The model instance; the same object is returned on repeated
            calls with the same name.

        Raises:
            ValueError: If no constructor is registered for ``model_name``.
        """
        key = model_name.lower()
        if key not in self.model_dict:
            raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")

        if key not in self._loaded_models:
            # NOTE(review): recent torchvision releases prefer ``weights=...``
            # over ``pretrained=True`` — confirm against the pinned version
            # before switching.
            self._loaded_models[key] = self.model_dict[key](pretrained=True)
        return self._loaded_models[key]

    def get_model_names(self):
        """Return the registered model names, capitalized for display."""
        return [name.capitalize() for name in self.model_dict]
46
+
47
+ # Preprocessor: Prepares image for model input
48
class Preprocessor:
    """Build the image transform pipeline that matches a model's input size."""

    # (resize, crop) used by every model without a dedicated entry below.
    _DEFAULT_SIZES = (256, 224)
    # Models that need a different input resolution, keyed by lowercase name.
    # Inception takes 299x299 crops; resizing to 342 first keeps the crop
    # inside the image instead of padding a 256px image up to 299.
    _MODEL_SIZES = {'inception': (342, 299)}

    def __init__(self):
        """Create the channel normalization shared by all pipelines."""
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @classmethod
    def input_sizes(cls, model_name):
        """Return the ``(resize, crop)`` sizes for ``model_name``.

        The lookup is case-insensitive so that display names such as
        "Inception" select the same sizes as "inception". Unknown names
        (including ``None``) fall back to the default sizes.
        """
        return cls._MODEL_SIZES.get(str(model_name).lower(), cls._DEFAULT_SIZES)

    def preprocess(self, model_name):
        """Return a transform that turns an image into a normalized tensor.

        Args:
            model_name: Name of the model the tensor is meant for, in any
                letter case.

        Returns:
            A ``transforms.Compose`` that resizes, center-crops, converts to
            a tensor and normalizes.
        """
        resize_size, crop_size = self.input_sizes(model_name)
        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            self.normalize,
        ])
62
+
63
+ # Postprocessor: Processes model output
64
class Postprocessor:
    """Turn raw model output into a label -> confidence mapping."""

    def __init__(self, labels):
        """Store the class labels.

        Args:
            labels: Sequence of label strings indexed by class id.
        """
        self.labels = labels

    def _top_confidences(self, logits, top_k=5):
        """Return the ``top_k`` most probable labels for one logits vector.

        Args:
            logits: 1-D tensor of class scores.
            top_k: Maximum number of entries to return; capped at the number
                of classes so short score vectors do not make ``topk`` fail.

        Returns:
            Dict mapping label to probability, most probable first.
        """
        probabilities = torch.nn.functional.softmax(logits, dim=0)
        count = min(top_k, probabilities.size(0))
        top_prob, top_catid = torch.topk(probabilities, count)
        return {
            self.labels[class_id]: probability
            for probability, class_id in zip(top_prob.tolist(), top_catid.tolist())
        }

    def postprocess_default(self, output):
        """Return the top-5 confidences for the first item of a batch.

        Args:
            output: Batched logits tensor; only ``output[0]`` is used.
        """
        return self._top_confidences(output[0])

    def postprocess_inception(self, output):
        """Return the top-5 confidences for Inception output.

        In evaluation mode the model returns a plain batched logits tensor.
        When auxiliary logits are enabled it returns a tuple whose first
        element holds the main logits. Both shapes are accepted, and the
        main logits of the first batch item are used.

        Args:
            output: Batched logits tensor, or a tuple starting with one.
        """
        if isinstance(output, tuple):
            output = output[0]
        return self._top_confidences(output[0])
79
+
80
+ # ImageClassifier: Classifies images using a selected model
81
+ class ImageClassifier:
82
+ def __init__(self, model_loader, preprocessor, postprocessor):
83
+ self.model_loader = model_loader
84
+ self.preprocessor = preprocessor
85
+ self.postprocessor = postprocessor
86
+
87
+ def classify(self, input_image, selected_model):
88
+ preprocess_input = self.preprocessor.preprocess(model_name=selected_model)
89
+ input_tensor = preprocess_input(input_image)
90
+ input_batch = input_tensor.unsqueeze(0)
91
+ model = self.model_loader.load_model(selected_model)
92
+
93
+ if torch.cuda.is_available():
94
+ input_batch = input_batch.to('cuda')
95
+ model.to('cuda')
96
+
97
+ model.eval()
98
+ with torch.no_grad():
99
+ output = model(input_batch)
100
+
101
+ if selected_model.lower() == 'inception':
102
+ return self.postprocessor.postprocess_inception(output)
103
+ else:
104
+ return self.postprocessor.postprocess_default(output)
105
+
106
+ # CIFAR10ImageProvider: Provides random images from CIFAR-10 dataset
107
class CIFAR10ImageProvider:
    """Serve random images from the CIFAR-10 test split."""

    def __init__(self, dataset_root='./data'):
        """Remember where the dataset lives; nothing is loaded yet.

        Args:
            dataset_root: Directory the dataset is downloaded to / read from.
        """
        self.dataset_root = dataset_root
        # Loaded on first use and then reused, so each request does not
        # re-open and re-verify the dataset files.
        self._dataset = None

    def _get_dataset(self):
        """Return the CIFAR-10 test split, loading it on the first call."""
        if self._dataset is None:
            # No transform: the dataset then yields PIL images directly,
            # which avoids a PIL -> tensor -> PIL round trip.
            self._dataset = datasets.CIFAR10(root=self.dataset_root, train=False, download=True)
        return self._dataset

    def get_random_image(self):
        """Return one uniformly chosen image from the dataset."""
        dataset = self._get_dataset()
        image, _ = dataset[random.randrange(len(dataset))]
        return image
117
+
118
+ # GradioApp: Sets up the Gradio interface
119
class GradioApp:
    """Gradio front end with an upload tab and a random-CIFAR-10-image tab."""

    def __init__(self, image_classifier, image_provider, model_list):
        """Keep references to the classifier, image source and model names.

        Args:
            image_classifier: Object whose ``classify(image, model)`` is used
                as the click handler of both "Classify" buttons.
            image_provider: Object whose ``get_random_image()`` is used as the
                click handler of the "Generate Random Image" button.
            model_list: Choices shown in the "Select Model" dropdowns.
        """
        self.image_classifier = image_classifier
        self.image_provider = image_provider
        self.model_list = model_list

    def _build_upload_tab(self):
        """Create the tab for classifying a user-uploaded image."""
        with gr.TabItem("Upload Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil', label="Upload Image")
                    model_choice = gr.Dropdown(self.model_list, label="Select Model")
                    classify_button = gr.Button("Classify")
                with gr.Column():
                    result_label = gr.Label(num_top_classes=5)
            classify_button.click(
                self.image_classifier.classify,
                inputs=[image_input, model_choice],
                outputs=result_label,
            )

    def _build_random_tab(self):
        """Create the tab for fetching and classifying a random image."""
        with gr.TabItem("Generate Random Image"):
            with gr.Row():
                with gr.Column():
                    generate_button = gr.Button("Generate Random Image")
                    image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
                with gr.Column():
                    model_choice = gr.Dropdown(self.model_list, label="Select Model")
                    classify_button = gr.Button("Classify")
                    result_label = gr.Label(num_top_classes=5)
            generate_button.click(
                self.image_provider.get_random_image,
                inputs=[],
                outputs=image_output,
            )
            classify_button.click(
                self.image_classifier.classify,
                inputs=[image_output, model_choice],
                outputs=result_label,
            )

    def launch(self):
        """Build the two-tab interface and start serving it."""
        with gr.Blocks() as demo, gr.Tabs():
            self._build_upload_tab()
            self._build_random_tab()

        demo.launch()
151
+
152
+ # Main Execution
153
def main():
    """Assemble the application components and launch the Gradio app."""
    # Reuse the module-level ``image_prediction_models`` mapping instead of
    # redefining an identical dictionary here.
    model_loader = ModelLoader(image_prediction_models)
    preprocessor = Preprocessor()

    # Fetch the class labels. The timeout keeps startup from hanging on a
    # stalled connection, and the status check prevents an HTTP error page
    # from being silently used as the label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    # splitlines() avoids the trailing empty label that split("\n") leaves
    # when the file ends with a newline; label indices are unchanged.
    labels = response.text.splitlines()

    postprocessor = Postprocessor(labels)
    image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
    image_provider = CIFAR10ImageProvider()

    model_list = model_loader.get_model_names()

    # Launch Gradio app
    app = GradioApp(image_classifier, image_provider, model_list)
    app.launch()


if __name__ == "__main__":
    main()