KabeerAmjad commited on
Commit
954ac21
·
verified ·
1 Parent(s): 75a5b88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -3,13 +3,14 @@ import torchvision.transforms as transforms
3
  import torchvision.models as models
4
  from PIL import Image
5
  import json
 
6
 
7
  # Load the model with updated weights parameter
8
  model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
9
  model.eval() # Set model to evaluation mode
10
 
11
  # Load the model's custom state_dict
12
- model_path = 'path_to_your_model_file.pth'
13
  try:
14
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
15
  model.load_state_dict(state_dict)
@@ -29,13 +30,13 @@ preprocess = transforms.Compose([
29
  ])
30
 
31
  # Load labels
32
- with open("imagenet_classes.json") as f:
33
  labels = json.load(f)
34
 
35
  # Function to predict image class
36
- def predict(image_path):
37
- # Open the image file
38
- input_image = Image.open(image_path).convert("RGB")
39
 
40
  # Preprocess the image
41
  input_tensor = preprocess(input_image)
@@ -54,9 +55,16 @@ def predict(image_path):
54
  _, predicted_idx = torch.max(output, 1)
55
  predicted_class = labels[str(predicted_idx.item())]
56
 
57
- return predicted_class
58
 
59
- # Example usage
60
- image_path = 'path_to_your_image.jpg'
61
- predicted_class = predict(image_path)
62
- print(f"Predicted class: {predicted_class}")
 
 
 
 
 
 
 
 
3
  import torchvision.models as models
4
  from PIL import Image
5
  import json
6
+ import gradio as gr
7
 
8
  # Load the model with updated weights parameter
9
  model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
10
  model.eval() # Set model to evaluation mode
11
 
12
  # Load the model's custom state_dict
13
+ model_path = 'food_classification_model.pth'
14
  try:
15
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
16
  model.load_state_dict(state_dict)
 
30
  ])
31
 
32
  # Load labels
33
+ with open("config.json") as f:
34
  labels = json.load(f)
35
 
36
  # Function to predict image class
37
+ def predict(image):
38
+ # Convert the uploaded file to a PIL image
39
+ input_image = image.convert("RGB")
40
 
41
  # Preprocess the image
42
  input_tensor = preprocess(input_image)
 
55
  _, predicted_idx = torch.max(output, 1)
56
  predicted_class = labels[str(predicted_idx.item())]
57
 
58
+ return f"Predicted class: {predicted_class}"
59
 
60
+ # Set up the Gradio interface
61
+ iface = gr.Interface(
62
+ fn=predict,
63
+ inputs=gr.inputs.Image(type="pil"),
64
+ outputs="text",
65
+ title="Food Classification Model",
66
+ description="Upload an image of food to classify it."
67
+ )
68
+
69
+ # Launch the Gradio app
70
+ iface.launch()