"""Gradio demo that classifies an image with an ImageNet-pretrained ResNet-18."""

from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Optional, Union

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights

# Image used when ``predict`` is called without an input.
DEFAULT_IMAGE_PATH = "examples/steak.jpeg"


@lru_cache(maxsize=1)
def _load_model():
    """Build the ResNet-18 model, its preprocessing transform and labels once.

    The result is cached so that repeated calls to ``predict`` reuse the same
    model instance instead of re-creating it on every request.

    Returns:
        A ``(model, transform, categories)`` tuple, with the model already
        switched to eval mode.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    return model, weights.transforms(), weights.meta["categories"]


def _to_pil_rgb(source) -> Image.Image:
    """Convert any supported input into an RGB ``PIL.Image``.

    Args:
        source: ``None`` (use the bundled example image), a NumPy array,
            a filesystem path (``str`` or ``Path``), or a ``PIL.Image``.

    Raises:
        TypeError: If ``source`` is of an unsupported type.
    """
    if source is None:
        source = DEFAULT_IMAGE_PATH
    if isinstance(source, np.ndarray):
        # Let Pillow infer the mode from the array shape, then normalise to
        # RGB so grayscale (H, W) and RGBA (H, W, 4) arrays are accepted too.
        return Image.fromarray(source.astype("uint8")).convert("RGB")
    if isinstance(source, Image.Image):
        return source.convert("RGB")
    if isinstance(source, (str, Path)):
        # ``convert`` forces the pixel data to load, so the file handle can
        # be closed as soon as the context manager exits.
        with Image.open(source) as handle:
            return handle.convert("RGB")
    raise TypeError(f"Unsupported image input type: {type(source).__name__}")


def predict(
    img_path: Optional[Union[np.ndarray, str, Path, Image.Image]] = None,
) -> str:
    """Classify an image and return the predicted category name.

    Args:
        img_path: The image to classify. May be a NumPy array (what
            ``gr.Image()`` passes by default), a path to an image file, a
            ``PIL.Image``, or ``None`` to classify the bundled example image.

    Returns:
        The label from the weights' ``meta["categories"]`` list whose logit
        is highest.

    Raises:
        TypeError: If ``img_path`` is of an unsupported type.
    """
    model, transform, categories = _load_model()

    batch = transform(_to_pil_rgb(img_path)).unsqueeze(0)

    with torch.inference_mode():
        logits = model(batch)

    # argmax over the logits picks the same class as argmax over softmax,
    # because softmax is monotonic.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = categories[pred_class]
    print(f"Predicted class: {predicted_label}")

    return predicted_label


demo = gr.Interface(
    predict,
    gr.Image(),
    "label",
    title="ResNet-18_1K 🚗",
    description="Upload an image to see classification probabilities based on ResNet-18 with 1K classes",
)

if __name__ == "__main__":
    demo.launch()