Spaces:
Sleeping
Sleeping
from functools import lru_cache
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights
@lru_cache(maxsize=1)
def _load_resnet():
    """Build the pretrained ResNet-18, its preprocessing transform, and labels.

    Cached so the weights are constructed once per process instead of on
    every request.

    Returns:
        A ``(model, transform, categories)`` tuple; ``model`` is in eval mode.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    return model, weights.transforms(), weights.meta["categories"]


def _to_pil_rgb(img_path):
    """Coerce any supported input kind into an RGB PIL image.

    Raises:
        TypeError: If ``img_path`` is not None, an ndarray, a PIL image, or a path.
    """
    if img_path is None:
        # No upload: fall back to the bundled example image. The context
        # manager closes the file; convert() returns an independent copy.
        with Image.open("examples/steak.jpeg") as im:
            return im.convert("RGB")
    if isinstance(img_path, np.ndarray):
        # Let PIL infer the mode, then convert: this also accepts grayscale
        # (H, W) and RGBA (H, W, 4) arrays, which fromarray(..., "RGB") rejects.
        return Image.fromarray(img_path.astype("uint8")).convert("RGB")
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    if isinstance(img_path, (str, Path)):
        with Image.open(img_path) as im:
            return im.convert("RGB")
    raise TypeError(
        f"Unsupported image input of type {type(img_path).__name__}; "
        "expected None, numpy.ndarray, PIL.Image.Image, str or Path"
    )


def predict(img_path=None) -> str:
    """Classify an image with the pretrained ResNet-18 (1K ImageNet classes).

    Args:
        img_path: The image to classify. It may be a NumPy array (e.g. from
            ``gr.Image()``), a PIL image, or a filesystem path. Pass ``None``
            to classify the bundled ``examples/steak.jpeg``.

    Returns:
        The name of the single most likely category.

    Raises:
        TypeError: If ``img_path`` is none of the supported kinds.
    """
    model, transform, categories = _load_resnet()
    img = transform(_to_pil_rgb(img_path))
    with torch.inference_mode():
        logits = model(img.unsqueeze(0))  # add a batch dimension of 1
    # softmax is monotonic, so the argmax of the raw logits is the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = categories[pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label
import gradio as gr

# Web UI: one image in, the top predicted ImageNet label out. `predict`
# returns a single label string, so the description promises only that
# (not per-class probabilities).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs="label",
    title="ResNet-18_1K π",
    description="Upload an image to see the predicted class based on ResNet-18 with 1K classes",
)

if __name__ == "__main__":
    demo.launch()