"""Flask web app that classifies uploaded face images as "with mask" / "without mask".

Routes:
    GET  /         -- serves ``index.html``.
    POST /predict  -- accepts a multipart upload under the ``image`` field and
                      returns the predicted class as JSON, together with the
                      uploaded image re-encoded as base64.
"""
import base64
import io
import os

import numpy as np
import torch
from flask import Flask, render_template, request, jsonify
from PIL import Image
from torchvision import transforms

from face_mask_detection import FaceMaskDetectionModel

app = Flask(__name__)

# os.path.join keeps the path portable; a literal backslash is only a path
# separator on Windows. The path is resolved relative to the working directory.
MODEL_PATH = os.path.join("models", "facemask_model_statedict1_f.pth")

# Build the model and load its trained weights (mapped onto the CPU).
model = FaceMaskDetectionModel()
# NOTE(review): torch.load unpickles the file, so only load weights from a
# trusted source.
model_state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(model_state_dict)
# Switch to inference mode once at startup; nothing below switches it back.
model.eval()

# Pre-processing applied to every uploaded image before inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Human-readable labels, indexed by the argmax of the model's output.
class_labels = ['without mask', 'with mask']


@app.route('/')
def index():
    """Serve the upload page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the uploaded image and echo it back.

    Expects a multipart/form-data POST with the file under the ``image`` field.

    Returns:
        200 with JSON ``{"prediction": int, "label": str, "probability": float,
        "image": str}``. Here ``image`` is the uploaded bytes, base64-encoded.
        400 with JSON ``{"error": str}`` if no file was sent or it is not a
        decodable image.
        500 with JSON ``{"error": str}`` for any other failure during inference.
    """
    uploaded = request.files.get('image')
    if uploaded is None:
        return jsonify({'error': "No file uploaded under the 'image' field."}), 400

    try:
        # Read the upload exactly once. read() consumes the stream, so a
        # second read() would return b"" and the echoed image would be empty.
        image_bytes = uploaded.read()

        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except OSError as e:
            # Pillow raises OSError (including UnidentifiedImageError) for
            # empty, corrupt or non-image data. That is a client error.
            return jsonify({'error': f"Could not decode image: {e}"}), 400

        # Add a leading batch dimension of size 1.
        image_tensor = transform(pil_image).unsqueeze(0)

        with torch.no_grad():
            output = model(image_tensor)
        app.logger.debug("Output: %s", output)

        # Take the single sample in the batch and turn its scores into
        # probabilities over the classes.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        app.logger.debug("Probabilities: %s", probabilities)

        predicted_class = int(torch.argmax(probabilities).item())
        predicted_probability = float(probabilities[predicted_class].item())
        predicted_label = class_labels[predicted_class]
        app.logger.info("Predicted Class: %s (%d)", predicted_label, predicted_class)
        app.logger.info("Probability: %.4f", predicted_probability)

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
        return jsonify({
            'prediction': predicted_class,
            'label': predicted_label,
            'probability': predicted_probability,
            'image': image_base64,
        })
    except Exception as e:
        # Top-level request boundary: log the traceback, then report the error.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # debug=True enables the interactive Werkzeug debugger and reloader;
    # do not expose this development server to untrusted networks.
    app.run(debug=True)