# Hugging Face Spaces status at capture time: Sleeping
import warnings

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
# Preprocessing preset bundled with torchvision's IMAGENET1K_V1 ResNet-18
# weights, so inputs are prepared the same way the backbone was pre-trained.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# Output labels, in the same order as the classifier's output units.
class_names = ['anger', 'disgust', 'fear', 'happy', 'pain', 'sad']
classes_count = len(class_names)

# Start from the ImageNet-pretrained backbone, then replace its head with one
# sized for our classes. The right-hand side is evaluated before assignment,
# so `model.fc.in_features` still refers to the original head's input width
# (512 for ResNet-18) rather than a hard-coded constant.
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, classes_count)
)

# Load the fine-tuned weights. torch.load unpickles the file, so only point it
# at trusted checkpoints. strict=False tolerates key mismatches, but the
# returned report is checked:
# - A missing classifier head would stay randomly initialised and the app
#   would silently return meaningless probabilities, so that is an error.
# - Any other mismatch is surfaced as a warning.
_load_result = model.load_state_dict(
    torch.load('./model_param.pt', map_location=torch.device('cpu')),
    strict=False,
)
_missing_head_keys = [
    key for key in _load_result.missing_keys if key.startswith('fc.')
]
if _missing_head_keys:
    raise RuntimeError(
        "Checkpoint './model_param.pt' is missing classifier weights: "
        f"{_missing_head_keys}"
    )
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "State dict mismatch while loading './model_param.pt': "
        f"missing={_load_result.missing_keys}, "
        f"unexpected={_load_result.unexpected_keys}"
    )

# Inference-only app: switch off training-mode behaviour once at load time.
model.eval()
def predict(img):
    """Classify the emotion shown in a single image.

    Args:
        img: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without providing an image.

    Returns:
        A ``{class_name: probability}`` dict with one entry per element of
        ``class_names`` (softmax probabilities, in label order), or ``None``
        when no image was supplied so the Label output is left empty.
    """
    if img is None:
        # The transform cannot handle None; return an empty result instead of
        # raising inside the request handler.
        return None
    batch = transformer(img).unsqueeze(0)  # add a leading batch dimension
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(batch), dim=1)[0]
    # One probability per output unit, paired with its label in order.
    return dict(zip(class_names, probs.tolist()))
# Web UI: an image upload box feeds predict(), and a Label component shows
# the returned class probabilities. num_top_classes is set to the number of
# classes, so every class is displayed.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=classes_count),
)
app.launch()