aaronespasa commited on
Commit
c6185a5
1 Parent(s): 7829397

Gradio Application

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +85 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ examples/
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from facenet_pytorch import MTCNN, InceptionResnetV1
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import zipfile
9
+
10
+ with zipfile.ZipFile("examples.zip","r") as zip_ref:
11
+ zip_ref.extractall(".")
12
+
13
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
14
+
15
+ mtcnn = MTCNN(
16
+ select_largest=False,
17
+ post_process=False,
18
+ device=DEVICE
19
+ ).to(DEVICE).eval()
20
+
21
+ model = InceptionResnetV1(
22
+ pretrained="vggface2",
23
+ classify=True,
24
+ num_classes=1,
25
+ device=DEVICE
26
+ )
27
+
28
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth")
29
+ model.load_state_dict(checkpoint['model_state_dict'])
30
+ model.to(DEVICE)
31
+ model.eval()
32
+
33
+ EXAMPLES_FOLDER = 'examples'
34
+ examples_names = os.listdir(EXAMPLES_FOLDER)
35
+ examples = []
36
+ for example_name in examples_names:
37
+ example_path = os.path.join(EXAMPLES_FOLDER, example_name)
38
+ label = example_name.split('_')[0]
39
+ example = {
40
+ 'path': example_path,
41
+ 'label': label
42
+ }
43
+ examples.append(example)
44
+ np.random.shuffle(examples) # shuffle
45
+
46
+ def predict(input_image:Image.Image, true_label:str):
47
+ """Predict the label of the input_image"""
48
+ face = mtcnn(input_image)
49
+ if face is None:
50
+ raise Exception('No face detected')
51
+ face = face.unsqueeze(0) # add the batch dimension
52
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
53
+
54
+ # convert the face into a numpy array to be able to plot it
55
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
56
+
57
+ face = face.to(DEVICE)
58
+ face = face.to(torch.float32)
59
+ face = face / 255.0
60
+ with torch.no_grad():
61
+ output = torch.sigmoid(model(face).squeeze(0))
62
+ prediction = "real" if output.item() < 0.5 else "fake"
63
+
64
+ real_prediction = 1 - output.item()
65
+ fake_prediction = output.item()
66
+
67
+ confidences = {
68
+ 'real': real_prediction,
69
+ 'fake': fake_prediction
70
+ }
71
+ return confidences, true_label, face_image_to_plot
72
+
73
+ interface = gr.Interface(
74
+ fn=predict,
75
+ inputs=[
76
+ gr.inputs.Image(label="Input Image", type="pil"),
77
+ "text"
78
+ ],
79
+ outputs=[
80
+ gr.outputs.Label(label="Class"),
81
+ "text",
82
+ gr.outputs.Image(label="Face")
83
+ ],
84
+ examples=[[examples[i]["path"], examples[i]["label"]] for i in range(10)]
85
+ ).launch()