KabeerAmjad commited on
Commit
75a5b88
·
verified ·
1 Parent(s): f63495a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -55
app.py CHANGED
@@ -1,66 +1,62 @@
1
- import gradio as gr
2
  import torch
3
- from torch import nn
4
- from torchvision import models, transforms
5
  from PIL import Image
6
- import os
7
-
8
- # Define the model path
9
- model_path = "food_classification_model.pth"
10
- huggingface_model_url = "https://huggingface.co/KabeerAmjad/food_classification_model/resolve/main/food_classification_model.pth"
11
-
12
- # Download the model from Hugging Face if it doesn't exist locally
13
- if not os.path.exists(model_path):
14
- import requests
15
- response = requests.get(huggingface_model_url)
16
- with open(model_path, "wb") as f:
17
- f.write(response.content)
18
-
19
- # Load the ResNet50 model
20
- model = models.resnet50(pretrained=False) # Don't load pre-trained weights here
21
- model.fc = nn.Linear(model.fc.in_features, 11) # Adjust the output layer to match your number of classes
22
-
23
- # Load the saved model weights
24
- model.load_state_dict(torch.load(model_path))
25
- model.eval() # Set the model to evaluation mode
26
-
27
- # Define the same preprocessing used during training
28
- transform = transforms.Compose([
29
- transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
 
 
32
  ])
33
 
34
- # Define the prediction function
35
- def classify_image(img):
 
 
 
 
 
 
 
36
  # Preprocess the image
37
- img = transform(img).unsqueeze(0) # Add batch dimension
 
 
 
 
 
 
38
 
39
- # Make prediction
40
  with torch.no_grad():
41
- outputs = model(img)
42
- probs = torch.softmax(outputs, dim=-1)
43
 
44
- # Get the label with the highest probability
45
- top_label = probs.argmax().item() # Get the index of the highest probability
46
-
47
- # Map label index to the actual class name
48
- label_mapping = {
49
- 0: "apple_pie", 1: "cheesecake", 2: "chicken_curry", 3: "french_fries",
50
- 4: "fried_rice", 5: "hamburger", 6: "hot_dog", 7: "ice_cream",
51
- 8: "omelette", 9: "pizza", 10: "sushi"
52
- }
53
- return label_mapping[top_label]
54
-
55
- # Create the Gradio interface
56
- iface = gr.Interface(
57
- fn=classify_image,
58
- inputs=gr.Image(type="pil"),
59
- outputs="text",
60
- title="Food Image Classification",
61
- description="Upload an image to classify if it’s an apple pie, etc."
62
- )
63
 
64
- # Launch the app
65
- iface.launch()
66
 
 
 
 
 
 
 
1
  import torch
2
+ import torchvision.transforms as transforms
3
+ import torchvision.models as models
4
  from PIL import Image
5
+ import json
6
+
7
+ # Load the model with updated weights parameter
8
+ model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
9
+ model.eval() # Set model to evaluation mode
10
+
11
+ # Load the model's custom state_dict
12
+ model_path = 'path_to_your_model_file.pth'
13
+ try:
14
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
15
+ model.load_state_dict(state_dict)
16
+ except RuntimeError as e:
17
+ print("Error loading state_dict:", e)
18
+ print("Ensure that the saved model architecture matches ResNet50.")
19
+
20
+ # Define the image transformations
21
+ preprocess = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
 
 
 
 
 
24
  transforms.ToTensor(),
25
+ transforms.Normalize(
26
+ mean=[0.485, 0.456, 0.406],
27
+ std=[0.229, 0.224, 0.225],
28
+ ),
29
  ])
30
 
31
+ # Load labels
32
+ with open("imagenet_classes.json") as f:
33
+ labels = json.load(f)
34
+
35
+ # Function to predict image class
36
+ def predict(image_path):
37
+ # Open the image file
38
+ input_image = Image.open(image_path).convert("RGB")
39
+
40
  # Preprocess the image
41
+ input_tensor = preprocess(input_image)
42
+ input_batch = input_tensor.unsqueeze(0) # Add batch dimension
43
+
44
+ # Check if a GPU is available and move the input and model to GPU
45
+ if torch.cuda.is_available():
46
+ input_batch = input_batch.to('cuda')
47
+ model.to('cuda')
48
 
49
+ # Perform inference
50
  with torch.no_grad():
51
+ output = model(input_batch)
 
52
 
53
+ # Get the predicted class with the highest score
54
+ _, predicted_idx = torch.max(output, 1)
55
+ predicted_class = labels[str(predicted_idx.item())]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ return predicted_class
 
58
 
59
+ # Example usage
60
+ image_path = 'path_to_your_image.jpg'
61
+ predicted_class = predict(image_path)
62
+ print(f"Predicted class: {predicted_class}")