"""Gradio demo: classify an uploaded image with a selectable pretrained torchvision model."""

import functools

import gradio as gr
import requests
import torch
import torchvision.models as models
from torch.nn import functional as F
from torchvision import transforms

# ImageNet-style evaluation preprocessing: resize the short side to 256,
# center-crop 224x224, convert to a tensor, and normalize with the channel
# mean/std constants below.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Class names, one per line; line i is used as the label for output index i.
# The timeout stops startup from hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error into an exception instead of silently
# using the error page's text as labels. splitlines() also tolerates CRLF files.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
labels = response.text.splitlines()

# Lower-case display key -> torchvision model builder.
image_prediction_models = {
    'resnet': models.resnet50,
    'alexnet': models.alexnet,
    'vgg': models.vgg16,
    'squeezenet': models.squeezenet1_0,
    'densenet': models.densenet161,
    'inception': models.inception_v3,
    'googlenet': models.googlenet,
    'shufflenet': models.shufflenet_v2_x1_0,
    'mobilenet': models.mobilenet_v2,
    'resnext': models.resnext50_32x4d,
    'wide_resnet': models.wide_resnet50_2,
    'mnasnet': models.mnasnet1_0,
    'efficientnet': models.efficientnet_b0,
    'regnet': models.regnet_y_400mf,
    'vit': models.vit_b_16,
    'convnext': models.convnext_tiny,
}

# Weight tag requested from every builder. IMAGENET1K_V1 is what the legacy
# ``pretrained=True`` flag mapped to, so predictions stay the same as before
# while avoiding the deprecated argument.
_WEIGHTS = "IMAGENET1K_V1"

# Upper bound on models kept in memory at once; some builders above
# (e.g. vgg16, vit_b_16) produce very large models.
_MODEL_CACHE_SIZE = 2

# Run inference on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_pretrained_model(model_name: str) -> torch.nn.Module:
    """Build a fresh pretrained model by (case-insensitive) name.

    Args:
        model_name: A key of ``image_prediction_models``, in any letter case.

    Returns:
        A newly constructed model with pretrained weights, exactly as the
        builder returns it (not moved to a device, mode left unchanged).

    Raises:
        ValueError: If ``model_name`` is not a known model.
    """
    model_name_lower = model_name.lower()
    try:
        model_class = image_prediction_models[model_name_lower]
    except KeyError:
        raise ValueError(
            f"Model {model_name} is not available for image prediction in torchvision.models"
        ) from None
    try:
        return model_class(weights=_WEIGHTS)
    except TypeError:
        # torchvision < 0.13 has no ``weights`` argument; use the legacy flag.
        return model_class(pretrained=True)


@functools.lru_cache(maxsize=_MODEL_CACHE_SIZE)
def _get_inference_model(model_name_lower: str) -> torch.nn.Module:
    """Return a cached, eval-mode copy of the named model on ``device``."""
    model = load_pretrained_model(model_name_lower)
    # eval() disables dropout and makes BatchNorm use its running statistics.
    # It also makes inception_v3/googlenet return a plain logits tensor instead
    # of a namedtuple that includes auxiliary outputs.
    model.eval()
    # The model must live on the same device as the input batch.
    return model.to(device)


def get_model_names(models_dict):
    """Return the keys of ``models_dict`` capitalized, for display in the UI."""
    return [name.capitalize() for name in models_dict.keys()]


model_list = get_model_names(image_prediction_models)


def classify_image(input_image, selected_model):
    """Return the 5 most likely labels for ``input_image`` under ``selected_model``.

    Args:
        input_image: A PIL image as supplied by ``gr.Image(type='pil')``, or
            None when nothing has been uploaded.
        selected_model: A display name from ``model_list`` (any letter case).

    Returns:
        A dict mapping label -> softmax probability for the top 5 classes,
        or an empty dict when no image is given.

    Raises:
        ValueError: If ``selected_model`` is not a known model.
    """
    if input_image is None:
        return {}
    model = _get_inference_model(selected_model.lower())
    # Normalize expects 3 channels; uploaded images may be RGBA or grayscale.
    input_batch = preprocess(input_image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    # output has shape (1, num_classes); softmax over the class axis of the
    # single example.
    probabilities = F.softmax(output[0], dim=0)
    top_prob, top_catid = torch.topk(probabilities, 5)
    return {
        labels[cat_id]: prob
        for prob, cat_id in zip(top_prob.tolist(), top_catid.tolist())
    }


interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type='pil'),
        # Preselect a model so the handler never receives None.
        gr.Dropdown(model_list, value=model_list[0]),
    ],
    outputs=gr.Label(num_top_classes=5),
)
interface.launch()