Junlinh commited on
Commit
2b5762c
·
1 Parent(s): 8b48602

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import torch
5
+ from timm.models import create_model
6
+ def predict(input_img):
7
+ input_img = Image.fromarray(np.uint8(input_img))
8
+ model1 = create_model(
9
+ 'resnet50',
10
+ drop_rate=0.5,
11
+ num_classes=1,)
12
+ model2 = create_model(
13
+ 'resnet50',
14
+ drop_rate=0.5,
15
+ num_classes=1,)
16
+
17
+ loc = 'cuda:{}'.format(0)
18
+ checkpoint1 = torch.load("./machine_full_best.tar", map_location=loc)
19
+ model1.load_state_dict(checkpoint1['state_dict'])
20
+ checkpoint2 = torch.load("./human_full_best.tar", map_location=loc)
21
+ model2.load_state_dict(checkpoint2['state_dict'])
22
+
23
+ my_transform = transforms.Compose([
24
+ transforms.RandomResizedCrop(224, (1, 1)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
27
+ std=[0.229, 0.224, 0.225]),])
28
+
29
+ input_img = my_transform(input_img).view(1,3,224,224)
30
+ model1.eval()
31
+ model2.eval()
32
+ result1 = round(model1(input_img).item(), 3)
33
+ result2 = round(model2(input_img).item(), 3)
34
+ result = 'MachineMem score = ' + str(result1) + ', HumanMem score = ' + str(result2) +'.'
35
+ return result
36
+
37
+ demo = gr.Interface(predict, gr.Image(), "text")
38
+ demo.launch(debug = True)