Ahmed-El-Sharkawy commited on
Commit
4fdffc1
·
verified ·
1 Parent(s): d4b2eed

Create App.py

Browse files
Files changed (1) hide show
  1. App.py +120 -0
App.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ import numpy as np
8
+
9
+ # Set device
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ # Load the main classifier (Main_Classifier_best_model.pth)
13
+ main_model = models.resnet18(pretrained=False)
14
+ num_ftrs = main_model.fc.in_features
15
+ main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
16
+ main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device))
17
+ main_model = main_model.to(device)
18
+ main_model.eval()
19
+
20
+ # Define class names for the main classifier based on folder structure
21
+ main_class_names = ['Clothing', 'Mobile Phones', 'Soda drinks']
22
+
23
+ # Sub-classifier models
24
+ def load_soda_drinks_model():
25
+ model = models.resnet18(pretrained=False)
26
+ num_ftrs = model.fc.in_features
27
+ model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
28
+ model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device))
29
+ model = model.to(device)
30
+ model.eval()
31
+ return model
32
+
33
+ def load_clothing_model():
34
+ model = models.resnet18(pretrained=False)
35
+ num_ftrs = model.fc.in_features
36
+ model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Pants, T-Shirt
37
+ model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device))
38
+ model = model.to(device)
39
+ model.eval()
40
+ return model
41
+
42
+ def load_mobile_phones_model():
43
+ model = models.resnet18(pretrained=False)
44
+ num_ftrs = model.fc.in_features
45
+ model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
46
+ model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device))
47
+ model = model.to(device)
48
+ model.eval()
49
+ return model
50
+
51
+ def convert_to_rgb(image):
52
+ """
53
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
54
+ This is to avoid transparency issues during model training.
55
+ """
56
+ if image.mode in ('P', 'RGBA'):
57
+ return image.convert('RGB')
58
+ return image
59
+
60
+ # Define preprocessing transformations (same used during training)
61
+ preprocess = transforms.Compose([
62
+ transforms.Lambda(convert_to_rgb),
63
+ transforms.Resize((224, 224)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
66
+ ])
67
+
68
+ def classify_image(image):
69
+ # Open the image using PIL
70
+ image = Image.fromarray(image)
71
+
72
+ # Preprocess the image
73
+ input_image = preprocess(image).unsqueeze(0).to(device)
74
+
75
+ # Perform inference with the main classifier
76
+ with torch.no_grad():
77
+ output = main_model(input_image)
78
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
79
+ confidence, predicted_class = torch.max(probabilities, 0)
80
+
81
+ # Main classifier result
82
+ main_prediction = main_class_names[predicted_class]
83
+ main_confidence = confidence.item()
84
+
85
+ # Load and apply the sub-classifier based on the main classification
86
+ if main_prediction == 'Soda drinks':
87
+ soda_model = load_soda_drinks_model()
88
+ sub_class_names = ['Miranda', 'Pepsi', 'Seven Up']
89
+ with torch.no_grad():
90
+ sub_output = soda_model(input_image)
91
+ elif main_prediction == 'Clothing':
92
+ clothing_model = load_clothing_model()
93
+ sub_class_names = ['Pants', 'T-Shirt']
94
+ with torch.no_grad():
95
+ sub_output = clothing_model(input_image)
96
+ elif main_prediction == 'Mobile Phones':
97
+ phones_model = load_mobile_phones_model()
98
+ sub_class_names = ['Apple', 'Samsung']
99
+ with torch.no_grad():
100
+ sub_output = phones_model(input_image)
101
+
102
+ # Perform inference with the sub-classifier
103
+ sub_probabilities = torch.nn.functional.softmax(sub_output[0], dim=0)
104
+ sub_confidence, sub_predicted_class = torch.max(sub_probabilities, 0)
105
+
106
+ sub_prediction = sub_class_names[sub_predicted_class]
107
+ sub_confidence = sub_confidence.item()
108
+
109
+ return f"Main Predicted Class: {main_prediction} (Confidence: {main_confidence:.4f})", \
110
+ f"Sub Predicted Class: {sub_prediction} (Confidence: {sub_confidence:.4f})"
111
+
112
+
113
+ # Gradio interface
114
+ image_input = gr.inputs.Image(shape=(224, 224), image_mode="RGB")
115
+ output_text = gr.outputs.Textbox()
116
+
117
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=output_text,
118
+ title="Main and Sub-Classifier System",
119
+ description="Upload an image to classify whether it belongs to Clothing, Mobile Phones, or Soda Drinks. Based on the prediction, it will further classify within the subcategory.",
120
+ theme="default").launch()