cats-vs-dogs / app.py
Manu8's picture
Update app.py
c5df48e
raw
history blame contribute delete
No virus
1.01 kB
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification
import gradio as gr
import torch
# Number of output classes produced by the classifier (one per entry in `labels`).
num_classes: int = 2
def predict(inp):
    """Classify an uploaded image and return per-label probabilities.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A ``{label: probability}`` dict of plain Python floats for
        ``gr.Label``, or ``None`` (clears the output) when no image was given.
    """
    if inp is None:
        return None
    # ToTensor preserves the source channel count (1 for "L", 4 for "RGBA",
    # palette images are not expanded), while the model is presumably built
    # for 3-channel RGB input. Normalize every upload to RGB first.
    rgb = inp.convert("RGB")
    inputs = data_transforms(rgb)[None]  # prepend a batch dimension of 1
    model.eval()
    with torch.no_grad():
        logits = model(inputs)["logits"]
    probs = torch.softmax(logits, dim=1)[0]
    # .tolist() yields plain floats, which serialize cleanly for gr.Label
    # (0-d tensors are not guaranteed to across Gradio versions).
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 224x224, then convert to a float tensor (ToTensor scales 8-bit PIL
# images to [0, 1]).
# NOTE(review): no Normalize step is applied; the original author left a
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) commented out here. Confirm this
# matches the preprocessing the checkpoint was fine-tuned with.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
data_transforms = transforms.Compose(_preprocess_steps)
# Download and instantiate the fine-tuned image classifier from the
# Hugging Face Hub at import time.
# NOTE(review): trust_remote_code=True allows the Hub repo to execute arbitrary
# Python on this machine. A stock image-classification checkpoint should load
# without it; confirm the repo ships no custom modeling code and drop the flag
# if so.
_MODEL_REPO = "Manu8/vit_cats-vs-dogs"
model = AutoModelForImageClassification.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
)
# Human-readable name for each output logit, by index.
# NOTE(review): order must match the label order used during fine-tuning;
# confirm against the model config's id2label.
labels = ["cat", "dog"]
# Expose predict() as a web UI: a PIL image upload in, a ranked
# label/confidence display out. launch() starts the Gradio server.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
)
demo.launch()