# cats-vs-dogs / app.py
# Author: Manu8 — commit 2d13268 ("Update app.py"); listed size 979 bytes.
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification
import gradio as gr
import torch
def predict(inp):
    """Classify an uploaded image and return per-label probabilities.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        dict mapping each entry of ``labels`` to its softmax probability as a
        plain ``float`` (the value type ``gr.Label`` displays), or ``None``
        when ``inp`` is ``None`` so the output component is simply cleared.
    """
    if inp is None:
        return None

    # ToTensor keeps the source channel count, so RGBA / greyscale uploads
    # would yield 4- or 1-channel tensors; force 3-channel RGB first.
    inputs = data_transforms(inp.convert("RGB"))[None]  # prepend batch dim
    model.eval()
    with torch.no_grad():
        outputs = model(inputs)

    # Standard transformers classification models return an output object
    # carrying ``.logits``; fall back to the raw value in case this
    # remote-code model returns a tensor directly.
    logits = getattr(outputs, "logits", outputs)
    probs = torch.softmax(logits, dim=1)[0]  # first (only) batch item

    # Pair labels with probabilities by position; float() unwraps 0-d tensors.
    return {label: float(p) for label, p in zip(labels, probs)}
# Preprocessing pipeline applied to each uploaded image before inference:
# resize to a fixed 224x224 input, then convert the PIL image to a tensor.
# NOTE(review): no normalization step is applied. Confirm this matches how
# the checkpoint was trained before enabling
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
data_transforms = transforms.Compose(_preprocessing_steps)
# Load the classification checkpoint from the Hub. trust_remote_code=True
# allows the repository to supply and execute its own model code, so only
# point this at a repository you trust.
model = AutoModelForImageClassification.from_pretrained(
    "Manu8/vit_cats-vs-dogs",
    trust_remote_code=True,
)

# Display names, indexed by the model's output position.
# NOTE(review): order assumed to match the label ids used in training.
labels = ["cat", "dog"]
# Web UI: upload an image and show the probability of every class.
# num_top_classes is derived from the label list so it always shows every
# class instead of relying on a hard-coded count.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(labels)),
)
demo.launch()