# Hugging Face Space (status: Sleeping)
import gradio as gr | |
import torch | |
from transformers import ViTForImageClassification, ViTFeatureExtractor | |
from PIL import Image | |
# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub checkpoint that provides both the classifier weights and the matching
# preprocessing configuration.
_CHECKPOINT = "shahmi0519/mango_artificial"

# Load the classifier, place it on `device` and put it in inference mode.
# nn.Module.to() and nn.Module.eval() both return the module itself, so the
# calls can be chained and `model` is the same object throughout.
# NOTE(review): with ignore_mismatched_sizes=True, a checkpoint whose
# classification head does not have 2 outputs is presumably loaded with a
# freshly (randomly) initialised head instead of raising — confirm the
# checkpoint was trained with num_labels=2.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor — consider migrating.
model = (
    ViTForImageClassification.from_pretrained(
        _CHECKPOINT,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    .to(device)
    .eval()
)

# Preprocessor from the same checkpoint; turns an image into the "pt" tensor
# inputs the model consumes (see predict_freshness).
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
# Display names for the classifier outputs: entry i names logit i.
# Edit to match your model. NOTE(review): confirm this order agrees with
# model.config.id2label of the loaded checkpoint.
class_labels = ["Artificial", "Natural"]
def predict_freshness(image):
    """Classify one image with the ViT model and return the predicted class name.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The ``class_labels`` entry for the highest-scoring logit, or
        ``"Class <idx>"`` if the model predicts an index the list does not cover.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a preprocessing stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # The ViT preprocessor expects 3-channel input. Uploaded PNGs can be RGBA,
    # and other uploads can be grayscale ("L") or palette ("P"). Those fail
    # during preprocessing (channel count vs. the 3-value normalisation mean).
    # For RGB input, convert("RGB") just returns a copy.
    image = image.convert("RGB")

    # Preprocess to PyTorch tensors on the same device as the model.
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # The model was put in eval mode once at load time. Here we only disable
    # autograd, since inference needs no gradients.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()

    # Fall back to a generic name if the head has more outputs than labels.
    if 0 <= predicted_class_idx < len(class_labels):
        return class_labels[predicted_class_idx]
    return f"Class {predicted_class_idx}"
# Copy and sample inputs for the Gradio UI.
# NOTE(review): the title/description speak of "freshness", but class_labels
# are "Artificial"/"Natural" — confirm the wording matches what the model
# predicts.
title = "Freshness Detector"
description = "Upload an image of fruit/vegetable to detect its freshness state"

# One single-element row per example, because the interface has exactly one
# input. The files are presumably shipped alongside this script — verify they
# exist.
examples = [[filename] for filename in ("apple.jpeg", "banana.jpeg", "tomato.jpeg")]
# UI components: the image input hands predict_freshness a PIL image, and the
# label output displays the string it returns.
_image_input = gr.Image(type="pil", label="Upload Image")
_label_output = gr.Label(label="Freshness State")

# Wire the classifier into a single-input, single-output web interface.
iface = gr.Interface(
    fn=predict_freshness,
    inputs=_image_input,
    outputs=_label_output,
    title=title,
    description=description,
    examples=examples,
)

# Start the server. share=True asks Gradio for a public share link.
# NOTE(review): when hosted on Hugging Face Spaces this flag is presumably
# unsupported and ignored — confirm whether it is still wanted.
iface.launch(share=True)