|
import os |
|
# Point the Hugging Face cache at /tmp. This has to happen before
# `transformers` is imported further down, because the library resolves its
# cache location from the environment when it is first imported.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME -- confirm against the pinned version.
os.environ.update(TRANSFORMERS_CACHE="/tmp/huggingface_cache")
|
|
|
from flask import Flask, request, render_template, jsonify |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import io |
|
|
|
app = Flask(__name__)

# --- Model ------------------------------------------------------------------
# Pretrained ViT-B/16 backbone; its classification head is replaced below.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): feature_extractor is loaded but not used anywhere visible in
# this file -- preprocessing goes through `transform` below instead.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

num_classes = 7
# The head must already have `num_classes` outputs when load_state_dict runs,
# otherwise the fine-tuned classifier tensors in the checkpoint won't fit.
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
model.load_state_dict(
    torch.load("skin_cancer_model.pth", map_location=torch.device("cpu"))
)
model.eval()  # inference behaviour (e.g. dropout off) for every request

# --- Label space --------------------------------------------------------------
# Model output index i corresponds to class_labels[i].
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Per-class minimum softmax probability required before the top class is
# reported; index-aligned with class_labels.
thresholds = [
    0.88134295,  # benign_keratosis-like_lesions
    0.43095806,  # basal_cell_carcinoma
    0.39622146,  # actinic_keratoses
    0.90647435,  # vascular_lesions
    0.8128958,   # melanocytic_Nevi
    0.05310565,  # melanoma
    0.15926854,  # dermatofibroma
]

# --- Preprocessing ------------------------------------------------------------
# Resize to the ViT input resolution, then convert to a float tensor.
# NOTE(review): there is no Normalize step; confirm this matches the
# preprocessing used when skin_cancer_model.pth was trained.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])
|
|
|
@app.route('/')
def index():
    """Serve the landing page with the application title filled in."""
    page_title = "Skin Cancer Classification Application"
    return render_template('index.html', appName=page_title)
|
|
|
def model_predict(image):
    """Run the classifier on a single PIL image.

    The image is converted to 3-channel RGB, preprocessed with the
    module-level ``transform``, given a batch dimension and passed through
    ``model`` with gradient tracking disabled.

    Args:
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, CMYK, ...).

    Returns:
        The raw model output for a batch of one; callers read ``.logits``.
    """
    # ToTensor keeps the source channel count, so RGBA/CMYK (4 channels) or
    # L/P (1 channel) uploads would not match the 3-channel ViT input.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    with torch.no_grad():
        return model(batch)
|
|
|
@app.route('/predictApi', methods=["POST"])
def api():
    """Classify an uploaded image and answer with JSON.

    Expects a multipart form upload in the ``fileup`` field. All responses use
    Flask's default status code; the outcome is signalled by the JSON keys:

    * ``{"prediction": <label>}`` -- the most probable class met its threshold.
    * ``{"Error": ...}`` -- no (or an empty) upload, or the most probable class
      fell below its per-class threshold.
    * ``{"Error": "An error occurred", "Message": ...}`` -- decoding or
      inference raised; the traceback is also logged server-side.
    """
    try:
        upload = request.files.get('fileup')
        # A browser form submitted without choosing a file still sends a
        # ``fileup`` part, but with an empty body; treat that like a missing
        # upload instead of letting Image.open fail with a decoder error.
        data = upload.read() if upload is not None else b""
        if not data:
            return jsonify({'Error': "Please try again. The Image doesn't exist"})

        result = model_predict(Image.open(io.BytesIO(data)))

        # Softmax over the single batch row; argmax picks the top class.
        probabilities = torch.softmax(result.logits, dim=1)[0]
        predicted_idx = int(torch.argmax(probabilities))
        max_prob = float(probabilities[predicted_idx])

        # NOTE(review): being below the class threshold only means the model
        # is not confident in its top class; the message below phrases that
        # as "no cancer", which may overstate it for end users.
        if max_prob < thresholds[predicted_idx]:
            return jsonify({'Error': 'No cancer detected or benign lesion.'})

        return jsonify({'prediction': class_labels[predicted_idx]})
    except Exception as e:
        # Request boundary: keep the JSON contract but don't lose the traceback.
        app.logger.exception("Prediction failed in /predictApi")
        return jsonify({'Error': 'An error occurred', 'Message': str(e)})
|
|
|
@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Render ``index.html``, classifying an uploaded image on POST.

    GET simply renders the page. On POST the image in the ``fileup`` form
    field is classified and the page is re-rendered with ``prediction`` set
    to one of:

    * the top class label, when its probability meets the class threshold;
    * a below-threshold notice;
    * ``'No file selected.'`` for a missing or empty upload;
    * ``'Error: ...'`` when decoding or inference raises (also logged).
    """
    app_name = "Skin Cancer Classification Application"

    if request.method != 'POST':
        return render_template('index.html', appName=app_name)

    try:
        upload = request.files.get('fileup')
        # Submitting the form without choosing a file sends an empty part;
        # report it the same way as a missing field.
        data = upload.read() if upload is not None else b""
        if not data:
            return render_template('index.html', prediction='No file selected.', appName=app_name)

        result = model_predict(Image.open(io.BytesIO(data)))

        # Softmax over the single batch row; argmax picks the top class.
        probabilities = torch.softmax(result.logits, dim=1)[0]
        predicted_idx = int(torch.argmax(probabilities))
        max_prob = float(probabilities[predicted_idx])

        # NOTE(review): below-threshold means "not confident", which the
        # user-facing wording presents as "no cancer" -- consider rewording.
        if max_prob < thresholds[predicted_idx]:
            message = 'No cancer detected or benign lesion.'
        else:
            message = class_labels[predicted_idx]
        return render_template('index.html', prediction=message, appName=app_name)
    except Exception as e:
        # Request boundary: show the error on the page and log the traceback.
        app.logger.exception("Prediction failed in /predict")
        return render_template('index.html', prediction='Error: ' + str(e), appName=app_name)
|
|
|
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach an error
    # page execute arbitrary Python, so it must not be on by default. Opt in
    # for local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled)
|
|