NehaBardeDUKE's picture
Update app.py
be8abf2
import torch
from torch import nn
import copy
import torchvision
from torchvision import transforms
from PIL import Image
from io import BytesIO
import tempfile
import gradio as gr
from gradio.components import Image, Textbox,Label
import os
def load_model(model_path):
    """Build a two-class ResNet-18 and load fine-tuned weights from disk.

    Args:
        model_path: Path to a checkpoint file containing the ``state_dict``
            of a ResNet-18 whose ``fc`` layer has two outputs.

    Returns:
        The torchvision ResNet-18 with the checkpoint weights loaded on CPU.
        Backbone parameters have ``requires_grad=False``; the replaced ``fc``
        head keeps the default ``requires_grad=True``.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` when the checkpoint's
            keys or tensor shapes do not match the architecture.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; on
    # PyTorch versions that support it, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    # The strict load_state_dict below overwrites every parameter and buffer,
    # so fetching ImageNet weights first would only add a network download at
    # startup. Calling resnet18() with no arguments yields a randomly
    # initialised network on both old (`pretrained=`) and new (`weights=`)
    # torchvision APIs.
    model = torchvision.models.resnet18()
    for param in model.parameters():
        param.requires_grad = False

    # Swap the 1000-way ImageNet head for a 2-way classifier so the module
    # shapes match the fine-tuned checkpoint.
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.load_state_dict(state_dict)
    return model
# Order matters: model_pred labels output column i with class_names[i].
class_names = ['artificial', 'human']

# Checkpoint path, resolved relative to the process working directory.
model_path = 'model2.pth'
# Loaded once at import time so every request reuses the same model object.
model = load_model(model_path)
# Built once at import time: the pipeline holds no per-request state, so there
# is no need to reconstruct it on every prediction. The mean/std values are
# the per-channel statistics torchvision documents for its ImageNet models.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def model_pred(img):
    """Classify one image and return a probability per class label.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``. Gradio
            passes ``None`` when the user submits without an image.

    Returns:
        A dict mapping each entry of ``class_names`` to the softmax
        probability (a Python ``float``) of the matching model output.

    Raises:
        gr.Error: If ``img`` is ``None``, so the UI shows a readable message
            instead of an internal traceback.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    # Normalize expects exactly three channels; convert grayscale, RGBA and
    # palette images so direct callers cannot trigger a shape mismatch.
    image_tensor = _PREPROCESS(img.convert("RGB")).unsqueeze(0)

    # eval() disables dropout and freezes batch-norm statistics; it is cheap
    # and idempotent, so it is kept here rather than assumed of the loader.
    model.eval()
    with torch.no_grad():
        # Batch size is 1, so take the only row of the (1, num_classes) output.
        pred_probs = torch.softmax(model(image_tensor), dim=1)[0]

    return {name: float(prob) for name, prob in zip(class_names, pred_probs)}
def _list_examples(directory):
    """Return ``[[path], ...]`` for the regular, non-hidden files in *directory*.

    Names are sorted so the examples appear in a stable order (``os.listdir``
    makes no ordering guarantee). Hidden files and sub-directories are skipped
    because they are not images. A missing directory yields an empty list
    instead of aborting application startup.
    """
    try:
        names = sorted(os.listdir(directory))
    except FileNotFoundError:
        return []
    return [
        [os.path.join(directory, name)]
        for name in names
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


inputs = gr.Image(type="pil")
example_list = _list_examples("examples")
interface = gr.Interface(
    fn=model_pred,
    inputs=inputs,
    outputs=gr.Label(num_top_classes=2, label="Classification"),
    title="Original vs AI-generated art Classification",
    description="Provide an image and get the predicted class label.",
    # None is Gradio's "no examples" value; avoid handing it an empty list.
    examples=example_list or None,
)
# Launched at import time on purpose: the hosting platform runs this file as
# a script and expects the server to start.
interface.launch()