Files changed (1) hide show
  1. app.py +29 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from transformers import ViTForImageClassification
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ model = ViTForImageClassification.from_pretrained('umutbozdag/plant-identity', num_labels=10, ignore_mismatched_sizes=True)
9
+ model.load_state_dict(torch.load('model.pth', map_location=device))
10
+ model.to(device)
11
+ model.eval()
12
+
13
+ # Define the prediction function
14
+ def predict_image(img):
15
+ transform = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
+ ])
20
+ img_t = transform(img).unsqueeze(0).to(device)
21
+ with torch.no_grad():
22
+ outputs = model(img_t).logits
23
+ _, predicted = torch.max(outputs, 1)
24
+ class_names = ["Aloe Vera", "Areca Palm", "Boston Fern", "Chinese evergreen", "Dracaena", "Money Tree", "Peace lily", "Rubber Plant", "Snake Plant", "ZZ Plant"]
25
+ return class_names[predicted.item()]
26
+
27
+ # Create a Gradio interface
28
+ interface = gr.Interface(fn=predict_image, inputs=gr.Image(type="pil"), outputs="text")
29
+ interface.launch(share = True)