Rick458 commited on
Commit
5b59271
1 Parent(s): 6de5ccf

deleted unnecessary methods

Browse files
Files changed (1) hide show
  1. app.py +7 -49
app.py CHANGED
@@ -10,42 +10,7 @@ from PIL import Image
10
 
11
 
12
 
13
-
14
-
15
- # Moving both Data and Model into GPU
16
-
17
- def get_default_device():
18
- """Pick GPU if available, else CPU"""
19
- if torch.cuda.is_available():
20
- return torch.device('cuda')
21
- else:
22
- return torch.device('cpu')
23
-
24
- def to_device(data, device):
25
- """Move tensor(s) to chosen device"""
26
- if isinstance(data, (list,tuple)):
27
- return [to_device(x, device) for x in data]
28
- return data.to(device, non_blocking=True)
29
-
30
- class DeviceDataLoader():
31
- """Wrap a dataloader to move data to a device"""
32
- def __init__(self, dl, device):
33
- self.dl = dl
34
- self.device = device
35
-
36
- def __iter__(self):
37
- """Yield a batch of data after moving it to device"""
38
- for b in self.dl:
39
- yield to_device(b, self.device)
40
-
41
- def __len__(self):
42
- """Number of batches"""
43
- return len(self.dl)
44
-
45
-
46
-
47
  # Defining our Class for just prediction
48
-
49
  def accuracy(outputs, labels):
50
  _, preds = torch.max(outputs, dim=1)
51
  return torch.tensor(torch.sum(preds == labels).item() / len(preds))
@@ -69,13 +34,10 @@ class ImageClassificationBase(nn.Module):
69
 
70
 
71
  # Defining our finetuned Resnet50 Architecture with our Classification layer
72
-
73
  class IndianFoodModelResnet50(ImageClassificationBase):
74
  def __init__(self, num_classes, pretrained=True):
75
  super().__init__()
76
- # Use a pretrained model
77
  self.network = models.resnet50(pretrained=pretrained)
78
- # Replace last layer
79
  self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)
80
 
81
  def forward(self, xb):
@@ -83,7 +45,7 @@ class IndianFoodModelResnet50(ImageClassificationBase):
83
 
84
 
85
 
86
- # for prediction
87
  @torch.no_grad()
88
  def evaluate(model, val_loader):
89
  model.eval()
@@ -92,7 +54,7 @@ def evaluate(model, val_loader):
92
 
93
 
94
 
95
- # initialising our model and moving it to GPU
96
  classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
97
  'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
98
  'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
@@ -103,40 +65,37 @@ to_device(model, device);
103
 
104
 
105
 
106
- # loading the model
107
  ckp_path = 'indianFood-resnet50.pth'
108
  model.load_state_dict(torch.load(ckp_path, map_location=torch.device('cpu')))
109
  model.eval()
110
 
111
 
112
 
113
- # image preprocessing before prediction
114
  stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
115
  img_tfms = tt.Compose([tt.Resize((224, 224)),
116
  tt.ToTensor(),
117
  tt.Normalize(*stats, inplace = True)])
118
 
119
  def predict_image(image, model):
120
- # Convert to a batch of 1
121
  xb = to_device(image.unsqueeze(0), device)
122
- # Get predictions from model
123
  yb = model(xb)
124
- # Pick index with highest probability
125
  _, preds = torch.max(yb, dim=1)
126
- # Retrieve the class label
127
  return classes[preds[0].item()]
128
 
129
 
130
 
 
131
  def classify_image(path):
132
  img = Image.open(path)
133
  img = img_tfms(img)
134
- #img = img.permute(2, 0, 1)
135
  label = predict_image(img, model)
136
-
137
  return label
138
 
139
 
 
 
140
  image = gr.inputs.Image(shape=(224, 224), type="filepath")
141
  label = gr.outputs.Label(num_top_classes=1)
142
 
@@ -149,7 +108,6 @@ gr.Interface(
149
  outputs=label,
150
  examples = [["idli.jpg"], ["naan.jpg"]],
151
  theme = "huggingface",
152
- layout = "horizontal",
153
  title = "DesiVisionNet: Desi Food Vision with ResNet",
154
  description = "This is a Gradio demo for multi-class image classification of Indian food amongst 20 classes. The DesiVisionNet achieved 90% accuracy on our test dataset, performing well for a relatively efficient model. See the GitHub project page for detailed information below. Here, we provide a demo for real-world food classification. To use it, simply upload your image, or click one of the examples to load them.",
155
  article = article
 
10
 
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Defining our Class for just prediction
 
14
  def accuracy(outputs, labels):
15
  _, preds = torch.max(outputs, dim=1)
16
  return torch.tensor(torch.sum(preds == labels).item() / len(preds))
 
34
 
35
 
36
  # Defining our finetuned Resnet50 Architecture with our Classification layer
 
37
  class IndianFoodModelResnet50(ImageClassificationBase):
38
  def __init__(self, num_classes, pretrained=True):
39
  super().__init__()
 
40
  self.network = models.resnet50(pretrained=pretrained)
 
41
  self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)
42
 
43
  def forward(self, xb):
 
45
 
46
 
47
 
48
+ # Prediction method
49
  @torch.no_grad()
50
  def evaluate(model, val_loader):
51
  model.eval()
 
54
 
55
 
56
 
57
+ # Initialising our model and moving it to CPU
58
  classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
59
  'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
60
  'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
 
65
 
66
 
67
 
68
+ # Loading the model
69
  ckp_path = 'indianFood-resnet50.pth'
70
  model.load_state_dict(torch.load(ckp_path, map_location=torch.device('cpu')))
71
  model.eval()
72
 
73
 
74
 
75
+ # Image preprocessing before prediction
76
  stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
77
  img_tfms = tt.Compose([tt.Resize((224, 224)),
78
  tt.ToTensor(),
79
  tt.Normalize(*stats, inplace = True)])
80
 
81
  def predict_image(image, model):
 
82
  xb = to_device(image.unsqueeze(0), device)
 
83
  yb = model(xb)
 
84
  _, preds = torch.max(yb, dim=1)
 
85
  return classes[preds[0].item()]
86
 
87
 
88
 
89
+ # Function handling input, processing and output
90
  def classify_image(path):
91
  img = Image.open(path)
92
  img = img_tfms(img)
 
93
  label = predict_image(img, model)
 
94
  return label
95
 
96
 
97
+
98
+ # Defining gradio interface functions
99
  image = gr.inputs.Image(shape=(224, 224), type="filepath")
100
  label = gr.outputs.Label(num_top_classes=1)
101
 
 
108
  outputs=label,
109
  examples = [["idli.jpg"], ["naan.jpg"]],
110
  theme = "huggingface",
 
111
  title = "DesiVisionNet: Desi Food Vision with ResNet",
112
  description = "This is a Gradio demo for multi-class image classification of Indian food amongst 20 classes. The DesiVisionNet achieved 90% accuracy on our test dataset, performing well for a relatively efficient model. See the GitHub project page for detailed information below. Here, we provide a demo for real-world food classification. To use it, simply upload your image, or click one of the examples to load them.",
113
  article = article