emc / app.py
aboba2285214's picture
Update app.py
ffdf90a verified
raw
history blame contribute delete
No virus
870 Bytes
import warnings

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
# Preprocessing preset that ships with the ImageNet-1K ResNet-18 weights.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# Output labels; predict() pairs the i-th model output with class_names[i].
# NOTE(review): presumably this order matches the one used in training - confirm.
class_names = ['anger', 'disgust', 'fear', 'happy', 'pain', 'sad']
classes_count = len(class_names)

# Start from the default pretrained ResNet-18 and replace its final layer
# with a head that emits one logit per class (512 is that layer's input width).
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(512, classes_count),
)

# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed PyTorch version; only ship checkpoints from a trusted source.
_state_dict = torch.load('./model_param.pt', map_location=torch.device('cpu'))

# strict=False tolerates key mismatches, which would otherwise go unnoticed
# and leave layers at their initial values. Surface them without failing.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "./model_param.pt did not fully match the model: "
        f"missing keys={list(_load_result.missing_keys)}, "
        f"unexpected keys={list(_load_result.unexpected_keys)}"
    )
def predict(img):
    """Classify an image and return a probability for every class.

    Args:
        img: Image in a form accepted by ``transformer`` (the Gradio
            interface in this file requests PIL images), or ``None`` when
            no image was supplied.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability, or ``None`` when ``img`` is ``None``.
    """
    # Nothing to classify (e.g. the form was submitted without an image):
    # return early instead of failing inside the preprocessing transform.
    if img is None:
        return None

    batch = transformer(img).unsqueeze(0)  # add the leading batch dimension
    model.eval()  # inference behaviour for batch-norm / dropout layers
    with torch.inference_mode():
        # Softmax over the class dimension; [0] selects the only batch item.
        probs = torch.softmax(model(batch), dim=1)[0]
    return dict(zip(class_names, probs.tolist()))
# Web UI: takes a PIL image, runs predict() on it and shows the result in a
# label widget sized to display every class.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=classes_count),
)

# Start the Gradio server as soon as the script is executed.
app.launch()