File size: 3,924 Bytes
4d9e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import base64
import io
import os

from flask import Flask, request, jsonify, render_template, send_file
from werkzeug.utils import secure_filename
import torch
from torchvision import transforms
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2

from model import tumor_model

# Initialize the Flask application; the route handlers defined below are
# registered on this object via @app.route.
app = Flask(__name__)

# Paths to the saved model files, relative to the current working directory.
# Built with os.path.join so the platform's own separator is used; a literal
# backslash only acts as a directory separator on Windows.
CLASSIFICATION_MODEL_PATH = os.path.join('models', 'tumor_model_statedict_f.pth')
SEGMENTATION_MODEL_PATH = os.path.join('models', 'unet_model.h5')

# Load the models
class MultiTaskModelWrapper:
    """Bundle the PyTorch tumour classifier and the Keras U-Net segmenter.

    Both models are loaded once, at construction time, from
    ``CLASSIFICATION_MODEL_PATH`` and ``SEGMENTATION_MODEL_PATH``.
    """

    # Maps the classifier's output index to a human-readable label.
    # NOTE(review): this order must match the label order used at training
    # time — confirm against the training code.
    CLASS_NAMES = {
        0: 'Glioma Tumor',
        1: 'Meningioma Tumor',
        2: 'No Tumor',
        3: 'Pituitary Tumor',
    }

    # (width, height) the input image is resized to before segmentation.
    SEGMENTATION_INPUT_SIZE = (128, 128)

    def __init__(self):
        # Built once here rather than on every predict() call.
        self._classification_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.segmentation_model = self.load_segmentation_model()
        self.classification_model = self.load_classification_model()

    def load_segmentation_model(self):
        """Load and return the pre-trained Keras U-Net from ``SEGMENTATION_MODEL_PATH``."""
        # Maps the name 'conv2d_transpose' to Keras' Conv2DTranspose layer;
        # presumably the saved model references that name — TODO confirm.
        return tf.keras.models.load_model(
            SEGMENTATION_MODEL_PATH,
            custom_objects={'conv2d_transpose': tf.keras.layers.Conv2DTranspose},
        )

    def load_classification_model(self):
        """Load weights into ``tumor_model``, switch it to eval mode, and return it.

        This mutates and returns the module-level ``tumor_model`` object
        imported from ``model``; it does not create a new instance.
        """
        # NOTE(review): torch.load unpickles the checkpoint, so only load
        # trusted files (consider weights_only=True if the installed torch
        # version supports it).
        state_dict = torch.load(CLASSIFICATION_MODEL_PATH, map_location=torch.device('cpu'))
        tumor_model.load_state_dict(state_dict)
        tumor_model.eval()
        return tumor_model

    def predict(self, image):
        """Classify and segment a single image.

        Args:
            image: a PIL image (the /predict route passes an RGB image).

        Returns:
            Tuple ``(input_image_base64, mask_image_base64, class_name,
            probability)``. Both images are base64-encoded PNG strings; the
            mask is the size of ``SEGMENTATION_INPUT_SIZE`` with values 0/255.
            ``class_name`` comes from ``CLASS_NAMES`` and ``probability`` is
            the softmax score of that class.
        """
        class_label, probability = self._classify(image)
        segmentation_mask = self._segment(image)

        mask_image_base64 = self._to_base64_png(Image.fromarray(segmentation_mask.astype(np.uint8)))
        input_image_base64 = self._to_base64_png(image)

        return input_image_base64, mask_image_base64, self.CLASS_NAMES[class_label], probability

    def _classify(self, image):
        """Run the classifier on one image; return ``(class_index, probability)``."""
        img_tensor = self._classification_transform(image).unsqueeze(0)  # batch of 1

        with torch.no_grad():
            logits = self.classification_model(img_tensor)

        class_probabilities = torch.nn.functional.softmax(logits, dim=1)
        # argmax over the class dimension, then take the single batch row.
        class_label = torch.argmax(class_probabilities, dim=1)[0].item()
        probability = class_probabilities[0, class_label].item()
        return class_label, probability

    def _segment(self, image):
        """Run the U-Net on one image; return a uint8 mask with values 0 or 255."""
        img_array = np.array(image.resize(self.SEGMENTATION_INPUT_SIZE))
        img_array = np.expand_dims(img_array, axis=0) / 255.0  # add batch dim, scale to [0, 1]

        # verbose=0 stops Keras printing a progress bar on every request.
        segmentation_output = self.segmentation_model.predict(img_array, verbose=0)
        # Indexing assumes a 4-D (batch, H, W, C) output; threshold channel 0
        # of the single batch item.
        return (segmentation_output > 0.5).astype(np.uint8)[0, :, :, 0] * 255

    @staticmethod
    def _to_base64_png(pil_image):
        """Encode a PIL image as PNG and return it as a base64 UTF-8 string."""
        buffer = io.BytesIO()
        pil_image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

# Load both models once, at import time; every /predict request reuses this
# single instance instead of reloading the weights per request.
model_wrapper = MultiTaskModelWrapper()

@app.route('/')
def home():
    """Render the ``index.html`` template for the site root."""
    template_name = 'index.html'
    rendered_page = render_template(template_name)
    return rendered_page

@app.route('/predict', methods=['POST'])
def predict():
    """Classify and segment an uploaded image.

    Expects a multipart upload in the form field ``file``. On success it
    responds with JSON keys ``input_image`` and ``mask_image`` (base64 PNG
    strings), ``class_label`` and ``probability``. On failure it responds
    with ``{'error': <message>}``.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'})
    file = request.files['file']
    # `not` also rejects a missing (None) filename, not only an empty string.
    if not file.filename:
        return jsonify({'error': 'No selected file'})

    try:
        image = Image.open(file.stream).convert('RGB')
        input_image_base64, mask_image_base64, class_label, probability = model_wrapper.predict(image)
    except Exception as e:
        # Keep the client-facing JSON error unchanged, but record the full
        # traceback server-side instead of silently discarding it.
        app.logger.exception('Prediction failed for upload %r', file.filename)
        # NOTE(review): errors are returned with HTTP 200; consider 400/500
        # once the frontend is confirmed to handle non-2xx responses.
        return jsonify({'error': str(e)})

    return jsonify({
        'input_image': input_image_base64,
        'mask_image': mask_image_base64,
        'class_label': class_label,
        'probability': probability,
    })

if __name__ == '__main__':
    # The Werkzeug interactive debugger can execute arbitrary code, so debug
    # mode is opt-in: set FLASK_DEBUG=1 (or true/yes) to enable it.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}
    app.run(debug=debug_enabled)