# Uploaded by hassaanik ("Upload 3 files", commit 5730892, verified).
from flask import Flask, render_template, request, jsonify
from PIL import Image
import torch
import io
import base64
from torchvision import transforms
from face_mask_detection import FaceMaskDetectionModel
import numpy as np
app = Flask(__name__)

# Build the face-mask classifier once at import time; the request handlers
# below use this module-level instance.
model = FaceMaskDetectionModel()

# Load the saved state dictionary, mapping all tensors onto the CPU.
# A forward slash is used as the path separator: Windows accepts it too,
# whereas a backslash is treated as a literal filename character on
# Linux/macOS, so the checkpoint would not be found there.
# The path is resolved relative to the current working directory.
model_state_dict = torch.load(
    "models/facemask_model_statedict1_f.pth",
    map_location=torch.device('cpu'),
)

# Copy the loaded weights into the model.
model.load_state_dict(model_state_dict)

# Switch to evaluation mode for inference (affects layers such as dropout /
# batch-norm, if the model contains any).
model.eval()
# Pre-processing pipeline applied to every uploaded image:
# resize to a fixed 224x224, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Display names indexed by the class id the model predicts:
# 0 -> 'without mask', 1 -> 'with mask'.
# NOTE(review): this order must match the label encoding used in training -- confirm.
class_labels = ['without mask', 'with mask']
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image as with / without a face mask.

    Expects a multipart form upload with the file in the ``image`` field.

    Returns:
        200 with JSON ``{'prediction': int, 'label': str,
        'probability': float, 'image': str}``, where ``prediction`` is the
        class index, ``label`` its name from ``class_labels`` and ``image``
        the uploaded bytes base64-encoded;
        400 with ``{'error': ...}`` if no file or an empty file was uploaded;
        500 with ``{'error': ...}`` if decoding or inference fails.
    """
    image_file = request.files.get('image')
    if image_file is None:
        return jsonify({'error': "No file uploaded under the 'image' field."}), 400

    # Read the upload exactly once. read() consumes the stream, so a second
    # read() returns b'' -- which previously made the echoed 'image' field
    # always empty.
    image_bytes = image_file.read()
    if not image_bytes:
        return jsonify({'error': 'Uploaded image is empty.'}), 400

    try:
        predicted_class, predicted_probability = _classify(image_bytes)
    except Exception as e:  # request boundary: report the failure as JSON
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

    label = class_labels[predicted_class]
    app.logger.debug("Predicted class: %s (probability %.4f)", label, predicted_probability)

    # Echo the uploaded image back alongside the prediction.
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')
    return jsonify({
        'prediction': predicted_class,
        'label': label,
        'probability': predicted_probability,
        'image': image_base64,
    })


def _classify(image_bytes):
    """Run the model on raw image bytes; return ``(class_index, probability)``."""
    # Decode and force 3-channel RGB, then add a leading batch dimension.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)

    # The module-level model is already in eval mode; no_grad skips autograd
    # bookkeeping during inference.
    with torch.no_grad():
        output = model(image_tensor)

    # Softmax over the first (only) batch row.
    # Assumes the model outputs raw logits of shape (batch, num_classes) -- TODO confirm.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    predicted_class = torch.argmax(probabilities).item()
    return predicted_class, probabilities[predicted_class].item()
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so enable debug mode only when explicitly requested, e.g.
    # `FLASK_DEBUG=1 python app.py`. The truthiness rules mirror Flask's own
    # handling of FLASK_DEBUG: unset, '0', 'false' and 'no' mean off.
    debug_flag = os.environ.get('FLASK_DEBUG', '')
    debug_enabled = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)