etemkocaaslan committed on
Commit
890a262
1 Parent(s): c957273

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -68
app.py CHANGED
@@ -1,14 +1,18 @@
1
  import torch
 
2
  import torchvision.models as models
3
  import torchvision.transforms as transforms
4
  import torchvision.datasets as datasets
5
  from torchvision.transforms import Compose
 
 
6
  import requests
 
7
  import random
8
- import gradio as gr
9
 
10
  # Predefined models available in torchvision
11
- image_prediction_models = {
12
  'resnet': models.resnet50,
13
  'alexnet': models.alexnet,
14
  'vgg': models.vgg16,
@@ -29,10 +33,10 @@ image_prediction_models = {
29
 
30
  # Load a pretrained model from torchvision
31
  class ModelLoader:
32
- def __init__(self, model_dict):
33
  self.model_dict = model_dict
34
 
35
- def load_model(self, model_name):
36
  model_name_lower = model_name.lower()
37
  if model_name_lower in self.model_dict:
38
  model_class = self.model_dict[model_name_lower]
@@ -41,7 +45,7 @@ class ModelLoader:
41
  else:
42
  raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
43
 
44
- def get_model_names(self):
45
  return [name.capitalize() for name in self.model_dict.keys()]
46
 
47
  # Preprocessor: Prepares image for model input
@@ -49,7 +53,7 @@ class Preprocessor:
49
  def __init__(self):
50
  self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
51
 
52
- def preprocess(self, model_name):
53
  input_size = 224
54
  if model_name == 'inception':
55
  input_size = 299
@@ -62,31 +66,31 @@ class Preprocessor:
62
 
63
  # Postprocessor: Processes model output
64
  class Postprocessor:
65
- def __init__(self, labels):
66
  self.labels = labels
67
 
68
- def postprocess_default(self, output):
69
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
70
- top_prob, top_catid = torch.topk(probabilities, 5)
71
  confidences = {self.labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
72
  return confidences
73
 
74
- def postprocess_inception(self, output):
75
- probabilities = torch.nn.functional.softmax(output[1], dim=0)
76
  top_prob, top_catid = torch.topk(probabilities, 5)
77
  confidences = {self.labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
78
  return confidences
79
 
80
  # ImageClassifier: Classifies images using a selected model
81
  class ImageClassifier:
82
- def __init__(self, model_loader, preprocessor, postprocessor):
83
  self.model_loader = model_loader
84
  self.preprocessor = preprocessor
85
  self.postprocessor = postprocessor
86
 
87
- def classify(self, input_image, selected_model):
88
- preprocess_input = self.preprocessor.preprocess(model_name=selected_model)
89
- input_tensor = preprocess_input(input_image)
90
  input_batch = input_tensor.unsqueeze(0)
91
  model = self.model_loader.load_model(selected_model)
92
 
@@ -96,7 +100,7 @@ class ImageClassifier:
96
 
97
  model.eval()
98
  with torch.no_grad():
99
- output = model(input_batch)
100
 
101
  if selected_model.lower() == 'inception':
102
  return self.postprocessor.postprocess_inception(output)
@@ -105,83 +109,64 @@ class ImageClassifier:
105
 
106
  # CIFAR10ImageProvider: Provides random images from CIFAR-10 dataset
107
  class CIFAR10ImageProvider:
108
- def __init__(self, dataset_root='./data'):
109
  self.dataset_root = dataset_root
 
110
 
111
- def get_random_image(self):
112
- cifar10 = datasets.CIFAR10(root=self.dataset_root, train=False, download=True, transform=transforms.ToTensor())
113
  random_idx = random.randint(0, len(cifar10) - 1)
114
  image, _ = cifar10[random_idx]
115
- image = transforms.ToPILImage()(image)
 
116
  return image
117
 
118
- # GradioApp: Sets up the Gradio interface
119
  class GradioApp:
120
- def __init__(self, image_classifier, image_provider, model_list):
121
  self.image_classifier = image_classifier
122
  self.image_provider = image_provider
123
  self.model_list = model_list
124
 
125
  def launch(self):
126
- with gr.Blocks() as demo:
127
- with gr.Tabs():
128
- with gr.TabItem("Upload Image"):
129
- with gr.Row():
130
- with gr.Column():
131
- upload_image = gr.Image(type='pil', label="Upload Image")
132
- model_dropdown_upload = gr.Dropdown(self.model_list, label="Select Model")
133
- classify_button_upload = gr.Button("Classify")
134
- with gr.Column():
135
- output_label_upload = gr.Label(num_top_classes=5)
136
  classify_button_upload.click(self.image_classifier.classify, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
137
 
138
- with gr.TabItem("Generate Random Image"):
139
- with gr.Row():
140
- with gr.Column():
141
- generate_button = gr.Button("Generate Random Image")
142
- random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
143
- with gr.Column():
144
- model_dropdown_random = gr.Dropdown(self.model_list, label="Select Model")
145
- classify_button_random = gr.Button("Classify")
146
- output_label_random = gr.Label(num_top_classes=5)
147
  generate_button.click(self.image_provider.get_random_image, inputs=[], outputs=random_image_output)
148
  classify_button_random.click(self.image_classifier.classify, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)
149
 
150
  demo.launch()
151
 
152
- # Main Execution
153
  if __name__ == "__main__":
154
- # Define available models
155
- image_prediction_models = {
156
- 'resnet': models.resnet50,
157
- 'alexnet': models.alexnet,
158
- 'vgg': models.vgg16,
159
- 'squeezenet': models.squeezenet1_0,
160
- 'densenet': models.densenet161,
161
- 'inception': models.inception_v3,
162
- 'googlenet': models.googlenet,
163
- 'shufflenet': models.shufflenet_v2_x1_0,
164
- 'mobilenet': models.mobilenet_v2,
165
- 'resnext': models.resnext50_32x4d,
166
- 'wide_resnet': models.wide_resnet50_2,
167
- 'mnasnet': models.mnasnet1_0,
168
- 'efficientnet': models.efficientnet_b0,
169
- 'regnet': models.regnet_y_400mf,
170
- 'vit': models.vit_b_16,
171
- 'convnext': models.convnext_tiny
172
- }
173
-
174
- # Initialize components
175
- model_loader = ModelLoader(image_prediction_models)
176
  preprocessor = Preprocessor()
177
  response = requests.get("https://git.io/JJkYN")
178
  labels = response.text.split("\n")
179
  postprocessor = Postprocessor(labels)
180
  image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
181
  image_provider = CIFAR10ImageProvider()
182
-
183
  model_list = model_loader.get_model_names()
184
 
185
- # Launch Gradio app
186
  app = GradioApp(image_classifier, image_provider, model_list)
187
- app.launch()
 
1
  import torch
2
+ from torch import Tensor as T
3
  import torchvision.models as models
4
  import torchvision.transforms as transforms
5
  import torchvision.datasets as datasets
6
  from torchvision.transforms import Compose
7
+ from torch.nn import Module
8
+ from torch.nn.functional import softmax
9
  import requests
10
+ from PIL import Image
11
  import random
12
+ from gradio import Blocks, Tabs, TabItem, Row, Column, Image, Dropdown, Button, Label
13
 
14
  # Predefined models available in torchvision
15
+ IMAGE_PREDICTION_MODELS = {
16
  'resnet': models.resnet50,
17
  'alexnet': models.alexnet,
18
  'vgg': models.vgg16,
 
33
 
34
  # Load a pretrained model from torchvision
35
  class ModelLoader:
36
+ def __init__(self, model_dict : dict):
37
  self.model_dict = model_dict
38
 
39
+ def load_model(self, model_name : str) -> Module :
40
  model_name_lower = model_name.lower()
41
  if model_name_lower in self.model_dict:
42
  model_class = self.model_dict[model_name_lower]
 
45
  else:
46
  raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
47
 
48
+ def get_model_names(self) -> list:
49
  return [name.capitalize() for name in self.model_dict.keys()]
50
 
51
  # Preprocessor: Prepares image for model input
 
53
  def __init__(self):
54
  self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
 
56
+ def preprocess(self, model_name : str) -> Compose:
57
  input_size = 224
58
  if model_name == 'inception':
59
  input_size = 299
 
66
 
67
  # Postprocessor: Processes model output
68
  class Postprocessor:
69
+ def __init__(self, labels : list):
70
  self.labels = labels
71
 
72
+ def postprocess_default(self, output) -> dict:
73
+ probabilities = softmax(output[0], dim=0)
74
+ top_prob , top_catid = torch.topk(probabilities, 5)
75
  confidences = {self.labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
76
  return confidences
77
 
78
+ def postprocess_inception(self, output) -> dict:
79
+ probabilities : T = softmax(output[1], dim=0)
80
  top_prob, top_catid = torch.topk(probabilities, 5)
81
  confidences = {self.labels[top_catid[i].item()]: top_prob[i].item() for i in range(top_prob.size(0))}
82
  return confidences
83
 
84
  # ImageClassifier: Classifies images using a selected model
85
  class ImageClassifier:
86
+ def __init__(self, model_loader : ModelLoader, preprocessor: Preprocessor, postprocessor : Postprocessor):
87
  self.model_loader = model_loader
88
  self.preprocessor = preprocessor
89
  self.postprocessor = postprocessor
90
 
91
+ def classify(self, input_image : Image, selected_model : str) -> dict:
92
+ preprocess_input : Compose = self.preprocessor.preprocess(model_name=selected_model)
93
+ input_tensor : T = preprocess_input(input_image)
94
  input_batch = input_tensor.unsqueeze(0)
95
  model = self.model_loader.load_model(selected_model)
96
 
 
100
 
101
  model.eval()
102
  with torch.no_grad():
103
+ output : T = model(input_batch)
104
 
105
  if selected_model.lower() == 'inception':
106
  return self.postprocessor.postprocess_inception(output)
 
109
 
110
  # CIFAR10ImageProvider: Provides random images from CIFAR-10 dataset
111
  class CIFAR10ImageProvider:
112
+ def __init__(self, dataset_root='./data', transform = transforms.ToTensor()):
113
  self.dataset_root = dataset_root
114
+ self.transform = transform
115
 
116
+ def get_random_image(self, resize_dim=(256, 256)) -> Image:
117
+ cifar10 = datasets.CIFAR10(root=self.dataset_root, train=False, download=True, transform= self.transform)
118
  random_idx = random.randint(0, len(cifar10) - 1)
119
  image, _ = cifar10[random_idx]
120
+ image= transforms.ToPILImage()(image) #bak buraya
121
+ image = image.resize(resize_dim, )
122
  return image
123
 
124
+ # Interface
125
  class GradioApp:
126
+ def __init__(self, image_classifier : ImageClassifier, image_provider : CIFAR10ImageProvider, model_list : list):
127
  self.image_classifier = image_classifier
128
  self.image_provider = image_provider
129
  self.model_list = model_list
130
 
131
  def launch(self):
132
+ with Blocks() as demo:
133
+ with Tabs():
134
+ with TabItem("Upload Image"):
135
+ with Row():
136
+ with Column():
137
+ upload_image = Image(type='pil', label="Upload Image")
138
+ model_dropdown_upload = Dropdown(self.model_list, label="Select Model")
139
+ classify_button_upload = Button("Classify")
140
+ with Column():
141
+ output_label_upload = Label(num_top_classes=5)
142
  classify_button_upload.click(self.image_classifier.classify, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
143
 
144
+ with TabItem("Generate Random Image"):
145
+ with Row():
146
+ with Column():
147
+ generate_button = Button("Generate Random Image")
148
+ random_image_output = Image(type='pil', label="Random CIFAR-10 Image")
149
+ with Column():
150
+ model_dropdown_random = Dropdown(self.model_list, label="Select Model")
151
+ classify_button_random = Button("Classify")
152
+ output_label_random = Label(num_top_classes=5)
153
  generate_button.click(self.image_provider.get_random_image, inputs=[], outputs=random_image_output)
154
  classify_button_random.click(self.image_classifier.classify, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)
155
 
156
  demo.launch()
157
 
158
+ # Main
159
  if __name__ == "__main__":
160
+ # Initialize
161
+ model_loader = ModelLoader(IMAGE_PREDICTION_MODELS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  preprocessor = Preprocessor()
163
  response = requests.get("https://git.io/JJkYN")
164
  labels = response.text.split("\n")
165
  postprocessor = Postprocessor(labels)
166
  image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
167
  image_provider = CIFAR10ImageProvider()
 
168
  model_list = model_loader.get_model_names()
169
 
170
+ # Launch
171
  app = GradioApp(image_classifier, image_provider, model_list)
172
+ app.launch()