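# NOTE(review): the three lines below appear to be page residue copied from the
# Hugging Face Spaces web UI; they are not part of the program.
# Spaces:
# Sleeping
# Sleeping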
import functools

import requests
import torch
import torchvision.models as models
from torch.nn import functional as F
from torchvision import transforms
# Per-channel (R, G, B) normalisation constants; presumably the ImageNet
# statistics the pretrained torchvision weights expect — confirm if changed.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every image before it reaches a model:
# resize the shorter side to 256, take the central 224x224 crop,
# convert to a float tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Class names, one per line; ``labels[i]`` is the name for model output index i.
# NOTE(review): presumably the ImageNet class list matching the pretrained
# weights' output order — confirm against the URL's contents.
_LABELS_URL = "https://git.io/JJkYN"

# A timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being used as the label list.
response = requests.get(_LABELS_URL, timeout=30)
response.raise_for_status()
# splitlines() also strips '\r' from CRLF files and drops the empty entry a
# trailing newline would add with split("\n").
labels = response.text.splitlines()
# Registry mapping lower-case dropdown keys to torchvision model builder
# functions (the builders are stored uncalled; nothing is loaded here).
image_prediction_models = dict(
    resnet=models.resnet50,
    alexnet=models.alexnet,
    vgg=models.vgg16,
    squeezenet=models.squeezenet1_0,
    densenet=models.densenet161,
    inception=models.inception_v3,
    googlenet=models.googlenet,
    shufflenet=models.shufflenet_v2_x1_0,
    mobilenet=models.mobilenet_v2,
    resnext=models.resnext50_32x4d,
    wide_resnet=models.wide_resnet50_2,
    mnasnet=models.mnasnet1_0,
    efficientnet=models.efficientnet_b0,
    regnet=models.regnet_y_400mf,
    vit=models.vit_b_16,
    convnext=models.convnext_tiny,
)
def load_pretrained_model(model_name):
    """Build a registered torchvision classifier with its default pretrained weights.

    Args:
        model_name: Name of an entry in ``image_prediction_models``, matched
            case-insensitively (the UI shows capitalised names).

    Returns:
        The model, already switched to evaluation mode for inference.

    Raises:
        ValueError: if *model_name* is ``None`` or not a registered model.
    """
    if model_name is None:
        raise ValueError("No model selected for image prediction")
    try:
        model_builder = image_prediction_models[model_name.lower()]
    except KeyError:
        raise ValueError(
            f"Model {model_name} is not available for image prediction in torchvision.models"
        ) from None
    # ``pretrained=True`` is deprecated since torchvision 0.13; "DEFAULT"
    # selects the recommended pretrained weights for each architecture.
    model = model_builder(weights="DEFAULT")
    # Without eval(), BatchNorm and Dropout run in training mode and the
    # predictions are wrong (and BatchNorm running stats get mutated).
    model.eval()
    return model
def get_model_names(models_dict):
    """Return display names for the keys of *models_dict*, in iteration order.

    Each key goes through ``.capitalize()``: the first character is
    upper-cased and the rest lower-cased (e.g. 'wide_resnet' -> 'Wide_resnet').
    """
    display_names = []
    for key in models_dict.keys():
        display_names.append(key.capitalize())
    return display_names
# Display names for the model dropdown, in the same order as the registry.
model_list = get_model_names(models_dict=image_prediction_models)
# Chosen once at import time; the model and every input batch go to this device.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@functools.lru_cache(maxsize=4)
def _get_inference_model(model_name):
    """Return *model_name* loaded, in eval mode and on ``_DEVICE``.

    Up to four recently used models are cached, so a request does not rebuild
    the network and reload its weights every time.
    """
    model = load_pretrained_model(model_name)
    # Inference mode: fixed BatchNorm statistics, Dropout disabled.
    model.eval()
    # Keep the model on the same device as the inputs to avoid a
    # device-mismatch error on CUDA machines.
    return model.to(_DEVICE)


def classify_image(input_image, selected_model):
    """Classify *input_image* with *selected_model* and return the top-5 labels.

    Args:
        input_image: Image to classify; expected to be a PIL image (the Gradio
            input is declared with ``type='pil'``). ``None`` means nothing
            was uploaded.
        selected_model: Model name as shown in the dropdown; resolved via
            ``load_pretrained_model``.

    Returns:
        dict mapping each of the five most probable labels to its softmax
        probability.

    Raises:
        ValueError: if no image was provided.
    """
    if input_image is None:
        raise ValueError("Please upload an image to classify.")
    # Force three channels: RGBA, greyscale or palette images would otherwise
    # break the three-channel Normalize step in ``preprocess``.
    input_batch = preprocess(input_image.convert("RGB")).unsqueeze(0).to(_DEVICE)
    model = _get_inference_model(selected_model)
    with torch.no_grad():
        output = model(input_batch)
    # output[0] is the logit vector for the single image in the batch.
    probabilities = F.softmax(output[0], dim=0)
    top_prob, top_catid = torch.topk(probabilities, 5)
    return {
        labels[cat_id]: prob
        for prob, cat_id in zip(top_prob.tolist(), top_catid.tolist())
    }
import gradio as gr

# Gradio UI: an image upload plus a model picker; shows the top-5 predictions.
interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        # Pre-select the first model so classify_image never receives None
        # when the user does not touch the dropdown.
        gr.Dropdown(choices=model_list, value=model_list[0], label="Model"),
    ],
    outputs=gr.Label(num_top_classes=5),
)
interface.launch()