"""Flask application that classifies skin-lesion images with a fine-tuned ViT.

Routes:
    /            -- render the upload page.
    /predictApi  -- POST an image (multipart field ``fileup``), get JSON back.
    /predict     -- POST an image from the HTML form, get the page re-rendered
                    with the prediction.
"""
import os

# Set cache directory to a writable location. This must happen before
# `transformers` is imported so the library picks the value up when it loads.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import io  # noqa: E402

from flask import Flask, request, render_template, jsonify  # noqa: E402
from transformers import ViTForImageClassification, ViTFeatureExtractor  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torchvision.transforms as transforms  # noqa: E402
from PIL import Image  # noqa: E402

app = Flask(__name__)

# Title passed to the template on every render.
APP_NAME = "Skin Cancer Classification Application"
# Name of the multipart form field that carries the uploaded image.
UPLOAD_FIELD = 'fileup'
# Message shown when the top class does not reach its threshold.
NO_CANCER_MESSAGE = 'No cancer detected or benign lesion.'

# Load the ViT model and its feature extractor.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): the feature extractor is loaded but not used by any route in
# this file (preprocessing is done by `transform` below). It is kept so the
# module-level name stays available to anything that imports it.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Load the trained model weights: replace the classification head with a
# 7-class linear layer, then restore the fine-tuned parameters on CPU.
num_classes = 7
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# SECURITY: torch.load unpickles the checkpoint; only load files you trust.
model.load_state_dict(
    torch.load("skin_cancer_model.pth", map_location=torch.device('cpu'))
)
model.eval()

# Class labels; the position of each label is the class index of the model.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Per-class minimum probability, indexed in the same order as `class_labels`.
# A prediction is only reported when the top class reaches its own threshold.
thresholds = [
    0.88134295,
    0.43095806,
    0.39622146,
    0.90647435,
    0.8128958,
    0.05310565,
    0.15926854,
]

# Image transformations applied before inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _render_index(**context):
    """Render ``index.html`` with the application name plus *context*."""
    return render_template('index.html', appName=APP_NAME, **context)


def _load_uploaded_image():
    """Return the uploaded image as a PIL image, or ``None`` if none was sent.

    A browser form submitted without choosing a file still contains the field,
    but with an empty filename; that case is treated as "no file" as well.
    """
    upload = request.files.get(UPLOAD_FIELD)
    if upload is None or not upload.filename:
        return None
    return Image.open(io.BytesIO(upload.read()))


def _classify(image):
    """Return the predicted class label for *image*, or ``None``.

    ``None`` means the most probable class did not reach its entry in
    ``thresholds``.
    """
    outputs = model_predict(image)
    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    predicted_idx = int(torch.argmax(probabilities))
    max_prob = float(probabilities[predicted_idx])
    if max_prob < thresholds[predicted_idx]:
        return None
    return class_labels[predicted_idx]


@app.route('/')
def index():
    """Render the upload page."""
    return _render_index()


def model_predict(image):
    """Run the model on a single PIL image and return the raw model output.

    The image is converted to RGB first so that RGBA, palette and grayscale
    uploads produce the three channels the model takes as input.
    """
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        outputs = model(batch)
    return outputs


@app.route('/predictApi', methods=["POST"])
def api():
    """Classify the uploaded image and return the result as JSON.

    Returns ``{'prediction': <label>}`` on success, otherwise a JSON object
    with an ``'Error'`` key (plus ``'Message'`` for unexpected failures).
    """
    try:
        image = _load_uploaded_image()
        if image is None:
            return jsonify({'Error': "Please try again. The Image doesn't exist"})

        prediction = _classify(image)
        if prediction is None:
            return jsonify({'Error': NO_CANCER_MESSAGE})

        return jsonify({'prediction': prediction})
    except Exception as e:
        # Top-level boundary: log the traceback, then report to the client.
        app.logger.exception("Prediction failed in /predictApi")
        return jsonify({'Error': 'An error occurred', 'Message': str(e)})


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Classify the image from the HTML form and re-render the page.

    GET requests simply render the page without a prediction.
    """
    if request.method != 'POST':
        return _render_index()

    try:
        image = _load_uploaded_image()
        if image is None:
            return _render_index(prediction='No file selected.')

        prediction = _classify(image)
        if prediction is None:
            return _render_index(prediction=NO_CANCER_MESSAGE)

        return _render_index(prediction=prediction)
    except Exception as e:
        # Top-level boundary: log the traceback, then show the error text.
        app.logger.exception("Prediction failed in /predict")
        return _render_index(prediction='Error: ' + str(e))


if __name__ == '__main__':
    # SECURITY: debug mode enables the Werkzeug interactive debugger, which
    # allows arbitrary code execution. It is therefore opt-in via FLASK_DEBUG
    # instead of being hard-coded on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled)