# app.py - uploaded by etemkocaaslan ("Create app.py", commit 2c85ae1, verified), 2.43 kB
import torch
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
import requests
# Per-channel normalization constants applied after the image is converted
# to a tensor (presumably the statistics the pretrained weights expect --
# confirm against the torchvision model documentation).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Input pipeline: resize, take the central 224-pixel crop, convert to a
# tensor, then normalize each channel.
_resize = transforms.Resize(256)
_center_crop = transforms.CenterCrop(224)
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD)

preprocess = transforms.Compose([_resize, _center_crop, _to_tensor, _normalize])
# Download the class-label list used to name the model's output indices.
# The timeout keeps startup from hanging on a stalled connection, and
# raise_for_status() fails fast instead of letting an HTTP error page be
# parsed as if it were the label file.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# One label per line; a label's position in the list is its class index.
labels = response.text.splitlines()
# Registry of selectable architectures: lowercase short name -> torchvision
# constructor. The constructors are stored uncalled, so nothing is built
# until a model is actually requested.
image_prediction_models = {
    'resnet': models.resnet50,
    'alexnet': models.alexnet,
    'vgg': models.vgg16,
    'squeezenet': models.squeezenet1_0,
    'densenet': models.densenet161,
    'inception': models.inception_v3,
    'googlenet': models.googlenet,
    'shufflenet': models.shufflenet_v2_x1_0,
    'mobilenet': models.mobilenet_v2,
    'resnext': models.resnext50_32x4d,
    'wide_resnet': models.wide_resnet50_2,
    'mnasnet': models.mnasnet1_0,
    'efficientnet': models.efficientnet_b0,
    'regnet': models.regnet_y_400mf,
    'vit': models.vit_b_16,
    'convnext': models.convnext_tiny,
}
def load_pretrained_model(model_name):
    """Instantiate the model registered under ``model_name``.

    Args:
        model_name: Key of ``image_prediction_models``. Matching is
            case-insensitive, so capitalized display names are accepted.

    Returns:
        The object returned by the registered constructor when called with
        ``pretrained=True``.

    Raises:
        ValueError: If ``model_name`` is not a string (for example ``None``
            when nothing was selected) or is not a registered name.
    """
    # Reject non-strings up front so callers always get ValueError rather
    # than an AttributeError from ``.lower()``.
    if not isinstance(model_name, str):
        raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")

    model_factory = image_prediction_models.get(model_name.lower())
    if model_factory is None:
        raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")

    # NOTE(review): ``pretrained=True`` is the legacy torchvision flag; newer
    # releases prefer ``weights=...``. Kept so the same weights keep loading.
    return model_factory(pretrained=True)
def get_model_names(models_dict):
    """Return the keys of ``models_dict`` as capitalized display names.

    Only the first character is upper-cased and the rest is lower-cased
    (``str.capitalize`` semantics); insertion order is preserved.
    """
    display_names = []
    for key in models_dict.keys():
        display_names.append(key.capitalize())
    return display_names
# Capitalized model names, offered as the choices of the model dropdown below.
model_list = get_model_names(image_prediction_models)
def classify_image(input_image, selected_model):
    """Classify an image with the chosen model and return its top predictions.

    Args:
        input_image: PIL image to classify.
        selected_model: Model name accepted by ``load_pretrained_model``.

    Returns:
        A dict mapping label text to probability for the (at most) five
        most probable classes.

    Raises:
        ValueError: If no image was supplied; also propagated from
            ``load_pretrained_model`` when the model name is not registered.
    """
    if input_image is None:
        raise ValueError("No image provided for classification")

    # The normalization step is configured for three channels, so grayscale
    # or RGBA uploads are converted first.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    input_batch = preprocess(input_image).unsqueeze(0)  # add the batch dimension

    # Model and input must live on the same device; moving only the input
    # fails with a device-mismatch error on GPU hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_pretrained_model(selected_model).to(device)
    # Inference mode: disables dropout and makes batch-norm use its running
    # statistics, so predictions are deterministic and meaningful.
    model.eval()
    input_batch = input_batch.to(device)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = F.softmax(output[0], dim=0)
    # Never request more entries than the model has classes.
    top_prob, top_catid = torch.topk(probabilities, min(5, probabilities.numel()))
    return {
        labels[class_index]: probability
        for probability, class_index in zip(top_prob.tolist(), top_catid.tolist())
    }
import gradio as gr
# Gradio UI: an image upload and a model picker feed classify_image, whose
# label -> confidence dict is shown as the top five classes.
image_input = gr.Image(type='pil')
model_dropdown = gr.Dropdown(model_list)
prediction_output = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=classify_image,
    inputs=[image_input, model_dropdown],
    outputs=prediction_output,
)
interface.launch()