Ahmed-El-Sharkawy commited on
Commit
eb179f1
·
verified ·
1 Parent(s): 0b3002e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ import os
9
+
10
+ # Set device
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # Load the main classifier (Main_Classifier_best_model.pth)
14
+ main_model = models.resnet18(weights=None) # Updated: weights=None
15
+ num_ftrs = main_model.fc.in_features
16
+ main_model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Soda drinks, Clothing, Mobile Phones
17
+ main_model.load_state_dict(torch.load('Main_Classifier_best_model.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
18
+ main_model = main_model.to(device)
19
+ main_model.eval()
20
+
21
+ # Define class names for the main classifier based on folder structure
22
+ main_class_names = ['Clothing', 'Mobile Phones', 'Soda drinks']
23
+
24
+ # Sub-classifier models
25
+ def load_soda_drinks_model():
26
+ model = models.resnet18(weights=None) # Updated: weights=None
27
+ num_ftrs = model.fc.in_features
28
+ model.fc = nn.Linear(num_ftrs, 3) # 3 classes: Miranda, Pepsi, Seven Up
29
+ model.load_state_dict(torch.load('Soda_drinks_best_model.pth', map_location=device, weights_only=True)) # Updated
30
+ model = model.to(device)
31
+ model.eval()
32
+ return model
33
+
34
+ def load_clothing_model():
35
+ model = models.resnet18(weights=None) # Updated: weights=None
36
+ num_ftrs = model.fc.in_features
37
+ model.fc = nn.Linear(num_ftrs, 3) # 2 classes: Pants, T-Shirt
38
+ model.load_state_dict(torch.load('Clothes_best_model.pth', map_location=device, weights_only=True)) # Updated
39
+ model = model.to(device)
40
+ model.eval()
41
+ return model
42
+
43
+ def load_mobile_phones_model():
44
+ model = models.resnet18(weights=None) # Updated: weights=None
45
+ num_ftrs = model.fc.in_features
46
+ model.fc = nn.Linear(num_ftrs, 2) # 2 classes: Apple, Samsung
47
+ model.load_state_dict(torch.load('Phone_best_model.pth', map_location=device, weights_only=True)) # Updated
48
+ model = model.to(device)
49
+ model.eval()
50
+ return model
51
+
52
+ def convert_to_rgb(image):
53
+ """
54
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
55
+ This is to avoid transparency issues during model training.
56
+ """
57
+ if image.mode in ('P', 'RGBA'):
58
+ return image.convert('RGB')
59
+ return image
60
+
61
+ # Define preprocessing transformations (same used during training)
62
+ preprocess = transforms.Compose([
63
+ transforms.Lambda(convert_to_rgb),
64
+ transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image
65
+ transforms.ToTensor(),
66
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
67
+ ])
68
+
69
+ # Load Meta's LLaMA model for generating product descriptions
70
+ def load_llama():
71
+ model_name = "meta-llama/Llama-3.2-1B-Instruct"
72
+ token = os.getenv("HUGGINGFACE_TOKEN")
73
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
74
+ model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=token).to(device)
75
+ return tokenizer, model
76
+
77
+ llama_tokenizer, llama_model = load_llama()
78
+
79
+ # Generate product description using LLaMA
80
+ def generate_description(category, subclass):
81
+ prompt = f"Generate a detailed and engaging product description for a {category} of type {subclass}."
82
+
83
+ inputs = llama_tokenizer.encode(prompt, return_tensors="pt").to(device)
84
+ outputs = llama_model.generate(inputs, max_length=100, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
85
+ description = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
86
+
87
+ return description
88
+
89
+
90
+ def classify_image(image):
91
+ # Open the image using PIL
92
+ image = Image.fromarray(image)
93
+
94
+ # Preprocess the image
95
+ input_image = preprocess(image).unsqueeze(0).to(device)
96
+
97
+ # Perform inference with the main classifier
98
+ with torch.no_grad():
99
+ output = main_model(input_image)
100
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
101
+ confidence, predicted_class = torch.max(probabilities, 0)
102
+
103
+ # Main classifier result
104
+ main_prediction = main_class_names[predicted_class]
105
+ main_confidence = confidence.item()
106
+ if main_confidence <=0.90:
107
+ main_prediction = 'Others'
108
+ main_confidence = 100-main_confidence
109
+ sub_prediction = "Undefined"
110
+ sub_confidence = -100
111
+ description = None
112
+ # Load and apply the sub-classifier based on the main classification
113
+ if main_prediction in ['Clothing', 'Mobile Phones', 'Soda drinks']:
114
+ if main_prediction == 'Soda drinks':
115
+ soda_model = load_soda_drinks_model()
116
+ sub_class_names = ['Miranda', 'Pepsi', 'Seven Up']
117
+ with torch.no_grad():
118
+ sub_output = soda_model(input_image)
119
+ elif main_prediction == 'Clothing':
120
+ clothing_model = load_clothing_model()
121
+ sub_class_names = ['Pants', 'T-Shirt','others']
122
+ with torch.no_grad():
123
+ sub_output = clothing_model(input_image)
124
+ elif main_prediction == 'Mobile Phones':
125
+ phones_model = load_mobile_phones_model()
126
+ sub_class_names = ['Apple', 'Samsung']
127
+ with torch.no_grad():
128
+ sub_output = phones_model(input_image)
129
+
130
+ # Perform inference with the sub-classifier
131
+ sub_probabilities = torch.nn.functional.softmax(sub_output[0], dim=0)
132
+ sub_confidence, sub_predicted_class = torch.max(sub_probabilities, 0)
133
+
134
+ sub_prediction = sub_class_names[sub_predicted_class]
135
+ sub_confidence = sub_confidence.item()
136
+
137
+ if sub_confidence < 0.90 :
138
+ sub_prediction = "Others"
139
+ sub_confidence = 100- sub_confidence
140
+ description=None
141
+ else:
142
+ # Generate product description
143
+ description = generate_description(main_prediction, sub_prediction)
144
+
145
+ return f"Main Predicted Class: {main_prediction} (Confidence: {main_confidence:.4f})", \
146
+ f"Sub Predicted Class: {sub_prediction} (Confidence: {sub_confidence:.4f})", \
147
+ f"Product Description: {description}"
148
+
149
+ # Gradio interface (updated)
150
+ image_input = gr.Image(image_mode="RGB") # Removed shape argument
151
+ output_text = gr.Textbox()
152
+
153
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text],
154
+ title="Main and Sub-Classifier System product description ",
155
+ description="Upload an image to classify whether it belongs to Clothing, Mobile Phones, or Soda Drinks. Based on the prediction, it will further classify within the subcategory and generate a detailed product description .",
156
+ theme="default").launch()