aboba2285214 commited on
Commit
26f6124
1 Parent(s): 324c3cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ import gradio as gr
5
+
6
+ transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
7
+ class_names = ['anger', 'disgust', 'fear', 'happy', 'pain', 'sad']
8
+ classes_count = len(class_names)
9
+
10
+ model = models.resnet18(weights='DEFAULT')
11
+ model.fc = nn.Sequential(
12
+ nn.Linear(512, classes_count)
13
+ )
14
+ model.load_state_dict(torch.load('./model_param.pt'), strict=False)
15
+
16
+ def predict(img):
17
+ img = transformer(img).unsqueeze(0)
18
+
19
+ model.eval()
20
+
21
+ with torch.inference_mode():
22
+ pred = torch.softmax(model(img), dim=1)
23
+
24
+ pred_and_labels = {class_names[i] : pred[0][i].item() for i in range(len(pred[0])) }
25
+
26
+ return pred_and_labels
27
+
28
+ app = gr.Interface(
29
+ predict,
30
+ gr.Image(type='pil'),
31
+ gr.Label(num_top_classes=classes_count)
32
+ )
33
+ app.launch()