# File size: 18,713 Bytes
# ea2329d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
import torch
import torchvision
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from pathlib import Path
import cv2
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
import logging
import warnings
from sklearn.preprocessing import StandardScaler
from typing import Optional, Dict, Any, Tuple
import json
import os
from datetime import datetime
import albumentations as A
from transformers import MarianMTModel, MarianTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
# Suppress every Python warning process-wide (this affects all imported
# libraries, not only this module).
warnings.filterwarnings('ignore')

# Root logging: INFO and above, written both to 'skin_diagnostic.log' in the
# current working directory and to a console stream handler, sharing one format.
logging.basicConfig(
    handlers=[
        logging.FileHandler('skin_diagnostic.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)

class ImageValidator:
    """Quality gate for uploaded skin images.

    Rejects inputs that are missing, not 3-channel images, too small, too
    dark or bright, too blurry, or nearly uniform in colour, before they
    reach the classifier.
    """

    @staticmethod
    def validate_image(image: np.ndarray) -> Tuple[bool, str]:
        """
        Validate image quality and characteristics.

        Args:
            image: ``H x W x 3`` array in RGB channel order (the OpenCV
                conversion below uses ``COLOR_RGB2GRAY``). ``None`` is accepted
                and rejected with a message, since that is what the UI passes
                when no image was uploaded.

        Returns:
            ``(is_valid, message)`` where ``message`` names the first failed
            check or confirms success. Never raises: unexpected errors are
            logged and reported as an invalid image.
        """
        # Structural checks first, so malformed input gets a specific message
        # instead of surfacing as an exception from NumPy/OpenCV.
        if image is None:
            return False, "No image provided. Please upload an image."
        if getattr(image, "ndim", None) != 3 or image.shape[2] != 3:
            # Grayscale / RGBA would fail later (colour conversions and the
            # 3-channel normalisation), so reject them here with a clear reason.
            return False, "Unsupported image format. Please upload a 3-channel color (RGB) image."

        try:
            # Check image dimensions
            if image.shape[0] < 224 or image.shape[1] < 224:
                return False, "Image resolution too low. Minimum 224x224 required."

            # Check if image is too dark or too bright (mean over all pixels/channels)
            brightness = np.mean(image)
            if brightness < 30:
                return False, "Image too dark. Please capture in better lighting."
            if brightness > 240:
                return False, "Image too bright. Please reduce exposure."

            # Blur check: low variance of the Laplacian means few sharp edges.
            laplacian_var = cv2.Laplacian(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), cv2.CV_64F).var()
            if laplacian_var < 100:
                return False, "Image is too blurry. Please provide a clearer image."

            # Colour variation: mean of the per-channel standard deviations.
            color_std = np.std(image, axis=(0, 1))
            if np.mean(color_std) < 20:
                return False, "Image lacks color variation. Please ensure proper lighting."

            return True, "Image validation successful"

        except Exception as e:
            # Boundary handler: keep the UI responsive, but record the traceback.
            logger.exception("Image validation error: %s", e)
            return False, "Error during image validation"

class AdvancedImageAnalysis:
    """Hand-crafted lesion descriptors computed from an RGB image.

    Shape features (asymmetry, border, diameter) use the largest external
    contour of an inverted Otsu threshold, i.e. pixels darker than the
    automatically chosen level are treated as lesion.
    """

    def __init__(self):
        # NOTE(review): not used by any method of this class; kept so the
        # attribute remains available to any external code that reads it.
        self.scaler = StandardScaler()

    def analyze_lesion(self, image: np.ndarray) -> Dict[str, float]:
        """
        Perform advanced analysis of skin lesion characteristics.

        Args:
            image: RGB image (uint8 assumed for the OpenCV conversions — TODO
                confirm for other dtypes).

        Returns:
            Mapping of feature name to value, or ``{}`` if any step fails
            (the failure is logged).
        """
        try:
            # Convert to different color spaces
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)

            return {
                'asymmetry': self._calculate_asymmetry(image),
                'border_irregularity': self._analyze_border(image),
                'color_variation': self._analyze_color(hsv),
                'diameter': self._estimate_diameter(image),
                'texture': self._analyze_texture(lab),
                'vascularity': self._analyze_vascularity(image),
            }

        except Exception as e:
            logger.exception("Error in lesion analysis: %s", e)
            return {}

    @staticmethod
    def _largest_contour(image: np.ndarray) -> Optional[np.ndarray]:
        """Return the largest external contour of the Otsu lesion mask, or None."""
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Inverted Otsu: darker-than-threshold pixels become foreground (255).
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None
        return max(contours, key=cv2.contourArea)

    @staticmethod
    def _mirror_asymmetry(mask: np.ndarray) -> float:
        """Return the mean non-overlap between a boolean mask and its mirror images.

        The mask is reflected about the vertical and the horizontal axis through
        its centroid. For each axis the score is the fraction of foreground
        pixels whose reflected position is not foreground; the two scores are
        averaged. 0.0 means symmetric about both axes; the maximum is 1.0.
        An empty mask scores 0.0.
        """
        scores = []
        # mask.T turns the horizontal-axis reflection into a column reflection.
        for m in (mask, mask.T):
            rows, cols = np.nonzero(m)
            if cols.size == 0:
                return 0.0
            mirrored = np.rint(2.0 * cols.mean() - cols).astype(int)
            # Reflections that land outside the image cannot overlap the lesion.
            inside = (mirrored >= 0) & (mirrored < m.shape[1])
            overlap = np.count_nonzero(m[rows[inside], mirrored[inside]])
            scores.append(1.0 - overlap / cols.size)
        return float(np.mean(scores))

    def _calculate_asymmetry(self, image: np.ndarray) -> float:
        """Calculate asymmetry score of the lesion (0.0 = mirror-symmetric, max 1.0)."""
        contour = self._largest_contour(image)
        # Zero-area outline (equivalent to the m00 == 0 moment check): nothing to compare.
        if contour is None or cv2.contourArea(contour) == 0:
            return 0.0
        # Rasterise the filled outline; mirroring the contour point array itself
        # (the previous cv2.flip approach) is a no-op on an N x 1 x 2 array.
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)
        return self._mirror_asymmetry(mask.astype(bool))

    def _analyze_border(self, image: np.ndarray) -> float:
        """Border irregularity as ``1 - circularity`` (about 0 for a circle)."""
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0

        area = cv2.contourArea(contour)
        if area == 0:
            return 0.0
        perimeter = cv2.arcLength(contour, True)

        # Isoperimetric ratio 4*pi*A/P^2: 1.0 for a perfect circle, lower for
        # irregular outlines (pixel discretisation can push it slightly off).
        circularity = 4 * np.pi * area / (perimeter * perimeter)
        return float(1 - circularity)

    def _analyze_color(self, hsv: np.ndarray) -> float:
        """Standard deviation of the hue channel over the whole image."""
        # NOTE(review): hue is circular (0-179 for uint8 in OpenCV), so reds on
        # both sides of the wrap-around inflate this value; a circular statistic
        # may be more appropriate.
        return float(np.std(hsv[:, :, 0]))

    def _estimate_diameter(self, image: np.ndarray) -> float:
        """Longest side of the lesion's bounding box, in pixels (not physical units)."""
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        _, _, w, h = cv2.boundingRect(contour)
        return float(max(w, h))

    def _analyze_texture(self, lab: np.ndarray) -> float:
        """Shannon entropy (bits) of a 16-bin grayscale intensity histogram."""
        # Back to grayscale via BGR; lab was produced from an RGB image.
        gray = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)

        # Normalised intensity histogram (first-order statistics, not a GLCM).
        hist = cv2.calcHist([gray], [0], None, [16], [0, 256])
        hist = hist.flatten() / hist.sum()

        # Small epsilon avoids log2(0) for empty bins.
        entropy = -np.sum(hist * np.log2(hist + 1e-7))
        return float(entropy)

    def _analyze_vascularity(self, image: np.ndarray) -> float:
        """Spread (95th minus 5th percentile) of channel 0, the red channel for RGB input."""
        red_channel = image[:, :, 0]
        return float(np.percentile(red_channel, 95) - np.percentile(red_channel, 5))

class SkinDiagnosticSystem:
    """End-to-end lesion analysis: image validation, CNN classification,
    hand-crafted lesion features and per-condition medical context.

    Args:
        model_path: Optional path to a checkpoint (as written by
            ``save_checkpoint``). If none is loaded, the classification head
            keeps its random initialisation and a warning is logged.
    """

    def __init__(self, model_path: Optional[str] = None):
        # Class labels, indexed by the model's output position.
        self.classes = [
            'Melanocytic nevi',
            'Melanoma',
            'Benign keratosis-like lesions',
            'Basal cell carcinoma',
            'Actinic keratoses',
            'Vascular lesions',
            'Dermatofibroma'
        ]

        # Coarse risk label shown to the user for each class.
        self.risk_levels = {
            'Melanoma': 'High',
            'Basal cell carcinoma': 'High',
            'Actinic keratoses': 'Moderate',
            'Vascular lesions': 'Low to Moderate',
            'Benign keratosis-like lesions': 'Low',
            'Melanocytic nevi': 'Low',
            'Dermatofibroma': 'Low'
        }

        # Initialize components
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_validator = ImageValidator()
        self.image_analyzer = AdvancedImageAnalysis()

        # Load model
        self.model = self._load_model(model_path)
        self.transform = self._get_transforms()

        # Static medical context shown alongside predictions
        self.medical_context = self._load_medical_context()

    def _load_model(self, model_path: Optional[str]) -> nn.Module:
        """Build EfficientNet-B4 with a head sized to ``self.classes`` and
        optionally load fine-tuned weights.

        Accepts either the dict written by ``save_checkpoint`` (weights under
        ``'model_state_dict'``) or a bare state_dict. Errors are logged and
        re-raised.
        """
        try:
            model = EfficientNet.from_pretrained('efficientnet-b4')
            num_ftrs = model._fc.in_features
            # Replacement head; its Linear layers start randomly initialised.
            model._fc = nn.Sequential(
                nn.Linear(num_ftrs, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, len(self.classes))
            )

            if model_path and os.path.exists(model_path):
                logger.info(f"Loading model checkpoint from {model_path}")
                # NOTE(review): torch.load unpickles the file - only load trusted
                # checkpoints (newer torch releases support weights_only=True).
                checkpoint = torch.load(model_path, map_location=self.device)
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                model.load_state_dict(state_dict)
                logger.info(f"Model checkpoint loaded. Epoch: {checkpoint.get('epoch', 'unknown')}")
            else:
                if model_path:
                    logger.warning(f"Model checkpoint not found: {model_path}")
                # Without fine-tuned weights the predictions are not meaningful.
                logger.warning("No checkpoint loaded: classification head is untrained; predictions are unreliable.")

            model = model.to(self.device)
            model.eval()
            return model

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def _get_transforms(self) -> transforms.Compose:
        """Resize to 224x224, convert to tensor and normalise per channel."""
        # NOTE(review): confirm 224x224 matches the input size used when the
        # checkpoint was fine-tuned. Mean/std are the usual ImageNet statistics.
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_medical_context(self) -> Dict[str, Any]:
        """Static description / warning / risk factors / follow-up per condition.

        Only some conditions have entries; callers use ``.get(condition, {})``.
        """
        return {
            'Melanoma': {
                'description': 'A serious form of skin cancer that begins in melanocytes.',
                'warning': 'URGENT: Immediate medical attention required. This is a potentially serious condition.',
                'risk_factors': [
                    'UV exposure',
                    'Fair skin',
                    'Family history',
                    'Multiple moles'
                ],
                'follow_up': 'Immediate dermatologist consultation required'
            },
            'Basal cell carcinoma': {
                'description': 'The most common type of skin cancer.',
                'warning': 'Medical attention required. While typically slow-growing, treatment is necessary.',
                'risk_factors': [
                    'Sun exposure',
                    'Fair skin',
                    'Age over 50',
                    'Prior radiation therapy'
                ],
                'follow_up': 'Schedule dermatologist appointment within 1-2 weeks'
            },
            # Add entries for other conditions...
        }

    def save_checkpoint(self, epoch: int, optimizer: torch.optim.Optimizer, loss: float) -> None:
        """Save model and optimizer state to ``checkpoints/model_checkpoint_epoch_<epoch>.pth``."""
        checkpoint_dir = Path('checkpoints')
        checkpoint_dir.mkdir(exist_ok=True)

        checkpoint_path = checkpoint_dir / f'model_checkpoint_epoch_{epoch}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, checkpoint_path)

        logger.info(f"Checkpoint saved: {checkpoint_path}")

    def analyze_image(self, image: np.ndarray) -> Dict[str, Any]:
        """Validate, classify and describe a lesion image.

        Args:
            image: RGB image array (see ``ImageValidator.validate_image``).

        Returns:
            On success a dict with ``condition``, ``confidence`` (percent),
            ``risk_level``, ``analysis`` (feature dict), ``medical_context``,
            ``warning`` and an ISO ``timestamp``; otherwise ``{'error': message}``.
        """
        try:
            # Validate image
            is_valid, validation_message = self.image_validator.validate_image(image)
            if not is_valid:
                return {'error': validation_message}

            pil_image = Image.fromarray(image)

            # Batch of one on the model's device.
            img_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(img_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)

            # Top-1 class; .item() turns the 1-element tensors into Python scalars.
            pred_prob, pred_idx = torch.max(probs, 1)
            condition = self.classes[pred_idx.item()]
            confidence = pred_prob.item() * 100

            analysis_results = self.image_analyzer.analyze_lesion(image)
            medical_info = self.medical_context.get(condition, {})

            response = {
                'condition': condition,
                'confidence': confidence,
                'risk_level': self.risk_levels.get(condition, 'Unknown'),
                'analysis': analysis_results,
                'medical_context': medical_info,
                'warning': medical_info.get('warning', ''),
                'timestamp': datetime.now().isoformat()
            }

            logger.info(f"Analysis completed for condition: {condition} (confidence: {confidence:.2f}%)")

            return response

        except Exception as e:
            # Boundary handler: log the traceback, return a user-safe message.
            logger.exception("Error in image analysis: %s", e)
            return {'error': 'Analysis failed. Please try again.'}

def create_gradio_interface():
    """Build the Gradio UI around one shared ``SkinDiagnosticSystem``.

    The interface takes an image, a target language code and an optional
    e-mail address, and returns the (optionally translated) text report.
    Translation and e-mail failures are logged and reported in the text
    instead of discarding the analysis result.
    """
    system = SkinDiagnosticSystem()

    # MarianMT checkpoint per target language; English is the source language
    # and is returned untranslated.
    translation_models = {
        'hi': ('Helsinki-NLP/opus-mt-en-hi', MarianTokenizer, MarianMTModel),
        'ta': ('Helsinki-NLP/opus-mt-en-ta', MarianTokenizer, MarianMTModel),
        'te': ('Helsinki-NLP/opus-mt-en-te', MarianTokenizer, MarianMTModel),
        'bn': ('Helsinki-NLP/opus-mt-en-bn', MarianTokenizer, MarianMTModel),
        'mr': ('Helsinki-NLP/opus-mt-en-mr', MarianTokenizer, MarianMTModel),
        'pa': ('Helsinki-NLP/opus-mt-en-pa', MarianTokenizer, MarianMTModel),
        'gu': ('Helsinki-NLP/opus-mt-en-gu', MarianTokenizer, MarianMTModel),
        'kn': ('Helsinki-NLP/opus-mt-en-kn', MarianTokenizer, MarianMTModel),
        'ml': ('Helsinki-NLP/opus-mt-en-ml', MarianTokenizer, MarianMTModel),
    }
    # Loaded (tokenizer, model) pairs, so each checkpoint is loaded once per
    # process instead of on every request (bounded by len(translation_models)).
    translator_cache = {}

    def get_translator(language):
        """Return the cached (tokenizer, model) pair for ``language``; KeyError if unsupported."""
        if language not in translator_cache:
            model_name, tokenizer_class, model_class = translation_models[language]
            translator_cache[language] = (
                tokenizer_class.from_pretrained(model_name),
                model_class.from_pretrained(model_name),
            )
        return translator_cache[language]

    def translate(text, language):
        """Translate ``text`` line by line, leaving lines without letters unchanged.

        Translating every line as its own sequence (in a single batch) avoids
        truncating the whole report to the model's maximum input length, which
        could otherwise silently drop trailing text such as the disclaimer.
        """
        tokenizer, model = get_translator(language)
        lines = text.split("\n")
        # Skip blank lines and separators ("=====", "-----") - nothing to translate.
        targets = [i for i, line in enumerate(lines) if any(ch.isalpha() for ch in line)]
        if not targets:
            return text
        inputs = tokenizer([lines[i] for i in targets], return_tensors="pt", padding=True, truncation=True)
        generated = model.generate(**inputs)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        for i, translated_line in zip(targets, decoded):
            lines[i] = translated_line
        return "\n".join(lines)

    def format_report(result):
        """Render a successful ``analyze_image`` result as the English text report."""
        output = "ANALYSIS RESULTS\n" + "="*50 + "\n\n"

        # Condition and Risk Level
        output += f"Detected Condition: {result['condition']}\n"
        output += f"Confidence: {result['confidence']:.2f}%\n"
        output += f"Risk Level: {result['risk_level']}\n\n"

        # Warning (if any)
        if result['warning']:
            output += f"⚠️ WARNING ⚠️\n{result['warning']}\n\n"

        # Detailed Analysis
        output += "Detailed Analysis:\n" + "-"*20 + "\n"
        for metric, value in result['analysis'].items():
            output += f"{metric}: {value:.2f}\n"

        # Medical Context
        if 'medical_context' in result and result['medical_context']:
            output += "\nMedical Context:\n" + "-"*20 + "\n"
            context = result['medical_context']
            output += f"Description: {context.get('description', 'N/A')}\n"

            if 'risk_factors' in context:
                output += "\nRisk Factors:\n"
                for factor in context['risk_factors']:
                    output += f"- {factor}\n"

            if 'follow_up' in context:
                output += f"\nRecommended Follow-up:\n{context['follow_up']}\n"

        # Timestamp
        output += f"\nAnalysis Timestamp: {result['timestamp']}\n"

        # Disclaimer
        output += "\n" + "="*50 + "\n"
        output += "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice. Please consult a qualified healthcare provider for proper diagnosis and treatment."
        return output

    def process_image(image, language, email=None):
        """Gradio callback: analyse, format, optionally translate and e-mail the report."""
        result = system.analyze_image(image)

        if 'error' in result:
            return f"Error: {result['error']}"

        output = format_report(result)

        # An unset dropdown arrives as None; treat it (like 'en') as English.
        if language and language != 'en':
            try:
                output = translate(output, language)
            except Exception:
                # Unsupported code or model load/generation failure: keep English.
                logger.exception("Translation to %r failed; returning English report", language)
                output += "\n\n(Translation unavailable; showing results in English.)"

        email = (email or "").strip()
        if email:
            try:
                send_email(email, output)
            except Exception:
                # Do not lose the analysis because delivery failed.
                logger.exception("Failed to send results e-mail")
                output += "\n\n(Could not send e-mail; please copy the results above.)"

        return output

    def send_email(to_email, message):
        """Send ``message`` as a plain-text e-mail to ``to_email``.

        SMTP settings are read from the environment (SMTP_USER, SMTP_PASSWORD,
        SMTP_HOST, SMTP_PORT) and default to the original placeholder values.
        Raises ValueError for an address containing line breaks, and propagates
        any smtplib error.
        """
        # NOTE(review): with a public share link anyone can trigger mail to any
        # address; consider verification or rate limiting.
        if "\r" in to_email or "\n" in to_email:
            raise ValueError("Invalid e-mail address")

        from_email = os.environ.get("SMTP_USER", "your_email@example.com")
        password = os.environ.get("SMTP_PASSWORD", "your_password")
        smtp_host = os.environ.get("SMTP_HOST", "smtp.example.com")
        smtp_port = int(os.environ.get("SMTP_PORT", "587"))

        msg = MIMEMultipart()
        msg['From'] = from_email
        msg['To'] = to_email
        msg['Subject'] = "Skin Lesion Analysis Results"

        msg.attach(MIMEText(message, 'plain'))

        # The context manager closes the connection even if login/send fails.
        with smtplib.SMTP(smtp_host, smtp_port) as server:
            server.starttls()
            server.login(from_email, password)
            server.sendmail(from_email, to_email, msg.as_string())

    # Only offer examples whose image files are actually present.
    example_rows = [
        ["example_melanoma.jpg", "en", ""],
        ["example_nevus.jpg", "hi", ""],
        ["example_bcc.jpg", "ta", ""]
    ]
    available_examples = [row for row in example_rows if Path(row[0]).is_file()] or None

    # Create enhanced Gradio interface with additional features
    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Upload Skin Image"),
            gr.Dropdown(choices=["en", "hi", "ta", "te", "bn", "mr", "pa", "gu", "kn", "ml"], value="en", label="Select Language"),
            gr.Textbox(label="Email (optional)", placeholder="Enter your email to receive results")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Advanced Skin Lesion Analysis System",
        description="""
        This system analyzes skin lesions using advanced computer vision and deep learning techniques.
        
        Key Features:
        - Lesion classification based on the HAM10000 dataset
        - Advanced image quality validation
        - Detailed analysis of lesion characteristics
        - Medical context and risk assessment
        - Option to receive results via email
        
        Important: This tool is for educational purposes only and should not replace professional medical diagnosis.
        """,
        examples=available_examples,
        analytics_enabled=False,
    )

    return iface

# Build the UI at module level so code that imports this module can still
# reach ``iface``; only start the (blocking) server when run as a script.
iface = create_gradio_interface()

if __name__ == "__main__":
    # Listen on all network interfaces and also request a public share link.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )