aryanwinningson commited on
Commit
65fc8dd
1 Parent(s): bf34735

added api-keys

Browse files
Files changed (1) hide show
  1. main.py +86 -80
main.py CHANGED
@@ -1,80 +1,86 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from facenet_pytorch import MTCNN, InceptionResnetV1
4
- import numpy as np
5
- from PIL import Image
6
- import cv2
7
- from pytorch_grad_cam import GradCAM
8
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
- from pytorch_grad_cam.utils.image import show_cam_on_image
10
- from fastapi import FastAPI, File, UploadFile
11
- from fastapi.responses import JSONResponse
12
- from io import BytesIO
13
-
14
- app = FastAPI()
15
-
16
- DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17
-
18
- mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()
19
-
20
- model = InceptionResnetV1(classify=True, num_classes=1, device=DEVICE)
21
- model_path = "Model/resnetinceptionv1_epoch_32.pth"
22
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
23
- model.load_state_dict(checkpoint['model_state_dict'])
24
- model.to(DEVICE)
25
- model.eval()
26
-
27
- def predict(input_image: Image.Image):
28
- """Predict the label of the input_image"""
29
- if input_image.mode == 'RGBA':
30
- input_image = input_image.convert('RGB')
31
- face = mtcnn(input_image)
32
- if face is None:
33
- raise Exception('No face detected')
34
- face = face.unsqueeze(0) # add the batch dimension
35
- face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
36
-
37
- prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
38
- prev_face = prev_face.astype('uint8')
39
-
40
- face = face.to(DEVICE)
41
- face = face.to(torch.float32)
42
- face = face / 255.0
43
- face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
44
-
45
- target_layers = [model.block8.branch1[-1]]
46
- cam = GradCAM(model=model, target_layers=target_layers)
47
- targets = [ClassifierOutputTarget(0)]
48
-
49
- grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
50
- grayscale_cam = grayscale_cam[0, :]
51
- visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
52
- face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)
53
-
54
- with torch.no_grad():
55
- output = torch.sigmoid(model(face).squeeze(0))
56
- prediction = "real" if output.item() < 0.5 else "fake"
57
-
58
- real_prediction = 1 - output.item()
59
- fake_prediction = output.item()
60
-
61
- confidences = {
62
- 'real': real_prediction,
63
- 'fake': fake_prediction
64
- }
65
- return confidences, prediction, face_with_mask
66
-
67
- @app.post("/predict")
68
- async def predict_api(file: UploadFile = File(...)):
69
- image = Image.open(BytesIO(await file.read()))
70
- try:
71
- confidences, prediction, face_with_mask = predict(image)
72
- _, buffer = cv2.imencode('.jpg', face_with_mask)
73
- face_with_mask_encoded = buffer.tobytes()
74
- return JSONResponse(content={
75
- "confidences": confidences,
76
- "prediction": prediction,
77
- "face_with_mask": face_with_mask_encoded.hex()
78
- })
79
- except Exception as e:
80
- return JSONResponse(content={"error": str(e)}, status_code=400)
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from facenet_pytorch import MTCNN, InceptionResnetV1
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ from pytorch_grad_cam import GradCAM
8
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
9
+ from pytorch_grad_cam.utils.image import show_cam_on_image
10
+ from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Header
11
+ from fastapi.responses import JSONResponse
12
+ from io import BytesIO
13
+
14
+ app = FastAPI()
15
+
16
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17
+
18
+ mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()
19
+
20
+ model = InceptionResnetV1(classify=True, num_classes=1, device=DEVICE)
21
+ model_path = "Model/resnetinceptionv1_epoch_32.pth"
22
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
23
+ model.load_state_dict(checkpoint['model_state_dict'])
24
+ model.to(DEVICE)
25
+ model.eval()
26
+
27
+ API_KEY = "c50dd5ady0uRL0rdnSaVyrArYaN161edb06af8"
28
+
29
+ def get_api_key(api_key: str = Header(...)):
30
+ if api_key != API_KEY:
31
+ raise HTTPException(status_code=403, detail="Could not validate credentials")
32
+
33
+ def predict(input_image: Image.Image):
34
+ """Predict the label of the input_image"""
35
+ if input_image.mode == 'RGBA':
36
+ input_image = input_image.convert('RGB')
37
+ face = mtcnn(input_image)
38
+ if face is None:
39
+ raise Exception('No face detected')
40
+ face = face.unsqueeze(0) # add the batch dimension
41
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
42
+
43
+ prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
44
+ prev_face = prev_face.astype('uint8')
45
+
46
+ face = face.to(DEVICE)
47
+ face = face.to(torch.float32)
48
+ face = face / 255.0
49
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
50
+
51
+ target_layers = [model.block8.branch1[-1]]
52
+ cam = GradCAM(model=model, target_layers=target_layers)
53
+ targets = [ClassifierOutputTarget(0)]
54
+
55
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
56
+ grayscale_cam = grayscale_cam[0, :]
57
+ visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
58
+ face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)
59
+
60
+ with torch.no_grad():
61
+ output = torch.sigmoid(model(face).squeeze(0))
62
+ prediction = "real" if output.item() < 0.5 else "fake"
63
+
64
+ real_prediction = 1 - output.item()
65
+ fake_prediction = output.item()
66
+
67
+ confidences = {
68
+ 'real': real_prediction,
69
+ 'fake': fake_prediction
70
+ }
71
+ return confidences, prediction, face_with_mask
72
+
73
+ @app.post("/predict")
74
+ async def predict_api(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
75
+ image = Image.open(BytesIO(await file.read()))
76
+ try:
77
+ confidences, prediction, face_with_mask = predict(image)
78
+ _, buffer = cv2.imencode('.jpg', face_with_mask)
79
+ face_with_mask_encoded = buffer.tobytes()
80
+ return JSONResponse(content={
81
+ "confidences": confidences,
82
+ "prediction": prediction,
83
+ "face_with_mask": face_with_mask_encoded.hex()
84
+ })
85
+ except Exception as e:
86
+ return JSONResponse(content={"error": str(e)}, status_code=400)