import torch |
import torchvision |
import gradio as gr |
import numpy as np |
import pandas as pd |
from PIL import Image |
import torch.nn as nn |
from pathlib import Path |
import cv2 |
from torchvision import transforms |
from efficientnet_pytorch import EfficientNet |
import logging |
import warnings |
from sklearn.preprocessing import StandardScaler |
from typing import Optional, Dict, Any, Tuple |
import json |
import os |
from datetime import datetime |
import albumentations as A |
from transformers import MarianMTModel, MarianTokenizer |
import matplotlib.pyplot as plt |
import seaborn as sns |
import smtplib |
from email.mime.text import MIMEText |
from email.mime.multipart import MIMEMultipart |
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings('ignore')

# Root logging configuration: records at INFO and above are written both to
# 'skin_diagnostic.log' (relative to the working directory) and to a stream
# handler.
_log_handlers = [
    logging.FileHandler('skin_diagnostic.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)

# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)
class ImageValidator:
    """Heuristic quality gate applied to an image before classification.

    The numeric thresholds are class attributes so they can be tuned in one
    place without touching the validation logic.
    """

    MIN_SIDE = 224          # minimum height and width, in pixels
    MIN_BRIGHTNESS = 30     # mean pixel value below this is rejected as too dark
    MAX_BRIGHTNESS = 240    # mean pixel value above this is rejected as too bright
    MIN_SHARPNESS = 100     # minimum variance of the Laplacian (blur check)
    MIN_COLOR_STD = 20      # minimum mean of the per-channel standard deviations

    @staticmethod
    def validate_image(image: np.ndarray) -> Tuple[bool, str]:
        """Check that ``image`` is usable for analysis.

        The checks run in this order and stop at the first failure:
        input present and 3-dimensional, resolution, mean brightness,
        sharpness (variance of the Laplacian of the grayscale image) and
        colour spread (mean of the per-channel standard deviations).

        Args:
            image: Array indexed as (height, width, channel). The grayscale
                conversion uses ``cv2.COLOR_RGB2GRAY``, so RGB channel order
                is assumed.

        Returns:
            ``(is_valid, message)``. ``message`` explains the rejection, or
            reports success when ``is_valid`` is True. The method never
            raises: unexpected errors are logged and reported as
            ``(False, "Error during image validation")``.
        """
        # A missing upload arrives as None; report that explicitly instead of
        # letting it surface as an opaque attribute error below.
        if image is None:
            return False, "No image provided. Please upload an image."
        # Every later check indexes the array as (height, width, channel).
        if getattr(image, "ndim", None) != 3:
            return False, "Unsupported image format. Please provide a color (RGB) image."

        limits = ImageValidator
        try:
            height, width = image.shape[:2]
            if height < limits.MIN_SIDE or width < limits.MIN_SIDE:
                return False, "Image resolution too low. Minimum 224x224 required."

            brightness = np.mean(image)
            if brightness < limits.MIN_BRIGHTNESS:
                return False, "Image too dark. Please capture in better lighting."
            if brightness > limits.MAX_BRIGHTNESS:
                return False, "Image too bright. Please reduce exposure."

            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            if laplacian_var < limits.MIN_SHARPNESS:
                return False, "Image is too blurry. Please provide a clearer image."

            color_std = np.std(image, axis=(0, 1))
            if np.mean(color_std) < limits.MIN_COLOR_STD:
                return False, "Image lacks color variation. Please ensure proper lighting."

            return True, "Image validation successful"
        except Exception as e:
            # Boundary handler: callers rely on always getting a tuple back.
            # logger.exception keeps the traceback for debugging.
            logger.exception("Image validation error: %s", e)
            return False, "Error during image validation"
class AdvancedImageAnalysis:
    """Hand-crafted lesion descriptors computed with OpenCV and NumPy.

    The shape-based descriptors (asymmetry, border, diameter) all work on the
    largest external contour of an inverted Otsu threshold of the grayscale
    image, i.e. the largest region darker than the Otsu threshold.
    """

    def __init__(self):
        # Not used by any method of this class; kept so existing code that
        # reads ``scaler`` keeps working.
        self.scaler = StandardScaler()

    def analyze_lesion(self, image: np.ndarray) -> Dict[str, float]:
        """Compute all lesion descriptors for ``image``.

        Args:
            image: Array indexed as (height, width, channel); the colour
                conversions used here assume RGB channel order.

        Returns:
            Dict with the keys ``asymmetry``, ``border_irregularity``,
            ``color_variation``, ``diameter``, ``texture`` and
            ``vascularity``. An empty dict is returned (and the error is
            logged) if any step fails.
        """
        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            return {
                'asymmetry': self._calculate_asymmetry(image),
                'border_irregularity': self._analyze_border(image),
                'color_variation': self._analyze_color(hsv),
                'diameter': self._estimate_diameter(image),
                'texture': self._analyze_texture(lab),
                'vascularity': self._analyze_vascularity(image),
            }
        except Exception as e:
            # Best-effort: the descriptors are supplementary, so a failure
            # here must not abort the caller. Keep the traceback in the log.
            logger.exception("Error in lesion analysis: %s", e)
            return {}

    @staticmethod
    def _largest_contour(image: np.ndarray) -> Optional[np.ndarray]:
        """Return the largest external contour of the thresholded image.

        Returns None when the inverted Otsu threshold yields no contour.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None
        return max(contours, key=cv2.contourArea)

    def _calculate_asymmetry(self, image: np.ndarray) -> float:
        """Score left/right asymmetry of the lesion in the range [0, 1].

        The filled lesion mask is mirrored about the vertical line through
        the contour centroid. The score is the fraction of lesion pixels
        whose mirror position is NOT a lesion pixel: 0.0 means the mask is
        symmetric about that line, 1.0 means no overlap at all.

        The previous implementation compared the contour with
        ``cv2.flip(contour, 1)``. A contour is an (N, 1, 2) point array with
        a single column, so that horizontal flip returned the same points
        and the score was always 0.0.
        """
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        moments = cv2.moments(contour)
        if moments['m00'] == 0:
            return 0.0

        # A pixel in column x mirrors to column (2 * cx - x). Round 2 * cx
        # once so mirrored columns stay integral.
        mirror_sum = int(round(2 * moments['m10'] / moments['m00']))

        height, width = image.shape[:2]
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)

        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            return 0.0
        mirrored_cols = mirror_sum - cols
        # Pixels whose mirror falls outside the frame have no counterpart.
        in_frame = (mirrored_cols >= 0) & (mirrored_cols < width)
        matched = np.count_nonzero(mask[rows[in_frame], mirrored_cols[in_frame]])
        return float(1.0 - matched / rows.size)

    def _analyze_border(self, image: np.ndarray) -> float:
        """Return ``1 - circularity`` of the largest contour.

        Circularity is ``4 * pi * area / perimeter ** 2``. Returns 0.0 when
        no contour is found or the contour is degenerate.
        """
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        area = cv2.contourArea(contour)
        perimeter = cv2.arcLength(contour, True)
        if area == 0 or perimeter == 0:
            return 0.0
        circularity = 4 * np.pi * area / (perimeter * perimeter)
        return float(1 - circularity)

    def _analyze_color(self, hsv: np.ndarray) -> float:
        """Return the standard deviation of channel 0 of ``hsv``."""
        return float(np.std(hsv[:, :, 0]))

    def _estimate_diameter(self, image: np.ndarray) -> float:
        """Return the longer side, in pixels, of the contour's bounding box.

        Returns 0.0 when no contour is found.
        """
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        _, _, w, h = cv2.boundingRect(contour)
        # boundingRect yields ints; cast so the return type matches the hint.
        return float(max(w, h))

    def _analyze_texture(self, lab: np.ndarray) -> float:
        """Return the Shannon entropy (bits) of a 16-bin intensity histogram.

        ``lab`` is converted LAB -> BGR -> grayscale before the histogram is
        taken.
        """
        bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        hist = cv2.calcHist([gray], [0], None, [16], [0, 256])
        probabilities = hist.flatten() / hist.sum()
        # The small epsilon keeps log2 finite for empty bins.
        entropy = -np.sum(probabilities * np.log2(probabilities + 1e-7))
        return float(entropy)

    def _analyze_vascularity(self, image: np.ndarray) -> float:
        """Return the 95th minus the 5th percentile of channel 0 of ``image``."""
        red_channel = image[:, :, 0]
        return float(np.percentile(red_channel, 95) - np.percentile(red_channel, 5))
class SkinDiagnosticSystem:
    """EfficientNet-B4 classifier plus descriptors for seven lesion classes.

    Construction loads the network (optionally restoring a checkpoint),
    builds the preprocessing pipeline and prepares the static medical
    context. ``analyze_image`` is the main entry point.
    """

    def __init__(self, model_path: Optional[str] = None):
        """Set up the model and its helpers.

        Args:
            model_path: Optional path to a checkpoint written by
                ``save_checkpoint``. When omitted or missing on disk, the
                classification head stays untrained and a warning is logged.

        Raises:
            Exception: Any error raised while building or loading the model
                is logged and re-raised.
        """
        # Order matters: the output layer in _load_model is sized from
        # ``self.classes`` and the model is moved to ``self.device``.
        self.classes = [
            'Melanocytic nevi',
            'Melanoma',
            'Benign keratosis-like lesions',
            'Basal cell carcinoma',
            'Actinic keratoses',
            'Vascular lesions',
            'Dermatofibroma'
        ]
        self.risk_levels = {
            'Melanoma': 'High',
            'Basal cell carcinoma': 'High',
            'Actinic keratoses': 'Moderate',
            'Vascular lesions': 'Low to Moderate',
            'Benign keratosis-like lesions': 'Low',
            'Melanocytic nevi': 'Low',
            'Dermatofibroma': 'Low'
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_validator = ImageValidator()
        self.image_analyzer = AdvancedImageAnalysis()
        self.model = self._load_model(model_path)
        self.transform = self._get_transforms()
        self.medical_context = self._load_medical_context()

    def _load_model(self, model_path: Optional[str]) -> nn.Module:
        """Build the network and optionally restore a checkpoint.

        The pretrained EfficientNet-B4 gets a new two-layer head sized for
        ``self.classes``. The result is moved to ``self.device`` and put in
        eval mode.

        Raises:
            Exception: Re-raised after logging if construction or loading
                fails.
        """
        try:
            model = EfficientNet.from_pretrained('efficientnet-b4')
            num_ftrs = model._fc.in_features
            model._fc = nn.Sequential(
                nn.Linear(num_ftrs, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, len(self.classes))
            )
            if model_path and os.path.exists(model_path):
                logger.info(f"Loading model checkpoint from {model_path}")
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints from a trusted source.
                checkpoint = torch.load(model_path, map_location=self.device)
                model.load_state_dict(checkpoint['model_state_dict'])
                # A missing 'epoch' entry should not fail a successful load.
                logger.info("Model checkpoint loaded. Epoch: %s", checkpoint.get('epoch', 'unknown'))
            else:
                if model_path:
                    logger.warning("Model checkpoint not found at %s", model_path)
                # The head created above has only its default initialisation,
                # so without a checkpoint its outputs are not trained.
                logger.warning(
                    "No fine-tuned checkpoint loaded: the classification head is "
                    "untrained and predictions will not be meaningful."
                )
            model = model.to(self.device)
            model.eval()
            return model
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

    def _get_transforms(self) -> transforms.Compose:
        """Return the preprocessing pipeline applied to each PIL image.

        Resizes to 224x224, converts to a tensor and normalises each channel
        with the mean/std values listed below.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_medical_context(self) -> Dict[str, Any]:
        """Return static descriptions, warnings and follow-up advice.

        Only 'Melanoma' and 'Basal cell carcinoma' have entries; callers get
        an empty dict for every other class via ``dict.get``.
        """
        return {
            'Melanoma': {
                'description': 'A serious form of skin cancer that begins in melanocytes.',
                'warning': 'URGENT: Immediate medical attention required. This is a potentially serious condition.',
                'risk_factors': [
                    'UV exposure',
                    'Fair skin',
                    'Family history',
                    'Multiple moles'
                ],
                'follow_up': 'Immediate dermatologist consultation required'
            },
            'Basal cell carcinoma': {
                'description': 'The most common type of skin cancer.',
                'warning': 'Medical attention required. While typically slow-growing, treatment is necessary.',
                'risk_factors': [
                    'Sun exposure',
                    'Fair skin',
                    'Age over 50',
                    'Prior radiation therapy'
                ],
                'follow_up': 'Schedule dermatologist appointment within 1-2 weeks'
            },
        }

    def save_checkpoint(self, epoch: int, optimizer: torch.optim.Optimizer, loss: float) -> None:
        """Write model and optimizer state to ``checkpoints/``.

        The file is named ``model_checkpoint_epoch_<epoch>.pth`` and holds
        the keys ``epoch``, ``model_state_dict``, ``optimizer_state_dict``
        and ``loss``.
        """
        checkpoint_dir = Path('checkpoints')
        checkpoint_dir.mkdir(exist_ok=True)
        checkpoint_path = checkpoint_dir / f'model_checkpoint_epoch_{epoch}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, checkpoint_path)
        logger.info(f"Checkpoint saved: {checkpoint_path}")

    def analyze_image(self, image: np.ndarray) -> Dict[str, Any]:
        """Validate, classify and describe a single image.

        Args:
            image: Image array as accepted by ``ImageValidator.validate_image``.

        Returns:
            On success, a dict with the keys ``condition``, ``confidence``
            (percentage), ``risk_level``, ``analysis``, ``medical_context``,
            ``warning`` and ``timestamp`` (ISO format, local time).
            On failure, ``{'error': <message>}``; this method never raises.
        """
        try:
            is_valid, validation_message = self.image_validator.validate_image(image)
            if not is_valid:
                return {'error': validation_message}

            pil_image = Image.fromarray(image)
            img_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(img_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)
                pred_prob, pred_idx = torch.max(probs, 1)

            # Convert the one-element tensor to a plain int before indexing.
            condition = self.classes[pred_idx.item()]
            confidence = pred_prob.item() * 100
            analysis_results = self.image_analyzer.analyze_lesion(image)
            medical_info = self.medical_context.get(condition, {})
            response = {
                'condition': condition,
                'confidence': confidence,
                'risk_level': self.risk_levels.get(condition, 'Unknown'),
                'analysis': analysis_results,
                'medical_context': medical_info,
                'warning': medical_info.get('warning', ''),
                'timestamp': datetime.now().isoformat()
            }
            logger.info(f"Analysis completed for condition: {condition} (confidence: {confidence:.2f}%)")
            return response
        except Exception as e:
            # Boundary handler for the UI: log with traceback, return a
            # generic message instead of raising.
            logger.exception("Error in image analysis: %s", e)
            return {'error': 'Analysis failed. Please try again.'}
def _format_report(result: Dict[str, Any]) -> str:
    """Render a successful ``analyze_image`` result as the plain-text report."""
    lines = [
        "ANALYSIS RESULTS",
        "=" * 50,
        "",
        f"Detected Condition: {result['condition']}",
        f"Confidence: {result['confidence']:.2f}%",
        f"Risk Level: {result['risk_level']}",
        "",
    ]
    if result['warning']:
        lines += ["⚠️ WARNING ⚠️", result['warning'], ""]

    lines += ["Detailed Analysis:", "-" * 20]
    lines += [f"{metric}: {value:.2f}" for metric, value in result['analysis'].items()]

    context = result.get('medical_context')
    if context:
        lines += ["", "Medical Context:", "-" * 20]
        lines.append(f"Description: {context.get('description', 'N/A')}")
        if 'risk_factors' in context:
            lines += ["", "Risk Factors:"]
            lines += [f"- {factor}" for factor in context['risk_factors']]
        if 'follow_up' in context:
            lines += ["", "Recommended Follow-up:", context['follow_up']]

    lines += [
        "",
        f"Analysis Timestamp: {result['timestamp']}",
        "",
        "=" * 50,
        "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice. Please consult a qualified healthcare provider for proper diagnosis and treatment.",
    ]
    return "\n".join(lines)


def create_gradio_interface():
    """Build the Gradio interface around a ``SkinDiagnosticSystem``.

    The interface takes an image, a language code and an optional e-mail
    address, and returns the text report (translated when a non-English
    language with a configured model is selected).

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    system = SkinDiagnosticSystem()
    translation_models = {
        'hi': ('Helsinki-NLP/opus-mt-en-hi', MarianTokenizer, MarianMTModel),
        'ta': ('Helsinki-NLP/opus-mt-en-ta', MarianTokenizer, MarianMTModel),
        'te': ('Helsinki-NLP/opus-mt-en-te', MarianTokenizer, MarianMTModel),
        'bn': ('Helsinki-NLP/opus-mt-en-bn', MarianTokenizer, MarianMTModel),
        'mr': ('Helsinki-NLP/opus-mt-en-mr', MarianTokenizer, MarianMTModel),
        'pa': ('Helsinki-NLP/opus-mt-en-pa', MarianTokenizer, MarianMTModel),
        'gu': ('Helsinki-NLP/opus-mt-en-gu', MarianTokenizer, MarianMTModel),
        'kn': ('Helsinki-NLP/opus-mt-en-kn', MarianTokenizer, MarianMTModel),
        'ml': ('Helsinki-NLP/opus-mt-en-ml', MarianTokenizer, MarianMTModel),
    }
    # language code -> (tokenizer, model); filled lazily so each translation
    # model is loaded at most once per process instead of on every request.
    loaded_translators = {}

    def translate(text, language):
        """Translate ``text`` line by line with the model for ``language``."""
        if language not in loaded_translators:
            model_name, tokenizer_class, model_class = translation_models[language]
            loaded_translators[language] = (
                tokenizer_class.from_pretrained(model_name),
                model_class.from_pretrained(model_name),
            )
        tokenizer, model = loaded_translators[language]

        # Translate each line separately so the report keeps its layout and
        # no single input grows long enough to be truncated. Lines without
        # letters (blank lines, rulers) are passed through untouched.
        lines = text.split("\n")
        targets = [i for i, line in enumerate(lines) if any(ch.isalpha() for ch in line)]
        if not targets:
            return text
        inputs = tokenizer([lines[i] for i in targets], return_tensors="pt", padding=True, truncation=True)
        generated = model.generate(**inputs)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        for i, translated_line in zip(targets, decoded):
            lines[i] = translated_line
        return "\n".join(lines)

    def send_email(to_email, message):
        """Send ``message`` to ``to_email`` over SMTP with STARTTLS.

        Connection settings come from the environment variables
        SMTP_FROM_EMAIL, SMTP_PASSWORD, SMTP_HOST and SMTP_PORT; the
        previous hard-coded placeholder values are the defaults.

        Raises:
            ValueError: If ``to_email`` is not shaped like an address.
            smtplib.SMTPException, OSError: On connection or delivery errors.
        """
        # The address is user input: require an '@' and refuse whitespace
        # or line breaks so it cannot carry extra header content.
        if '@' not in to_email or any(ch.isspace() for ch in to_email):
            raise ValueError("Invalid email address")

        from_email = os.environ.get("SMTP_FROM_EMAIL", "your_email@example.com")
        password = os.environ.get("SMTP_PASSWORD", "your_password")
        host = os.environ.get("SMTP_HOST", 'smtp.example.com')
        port = int(os.environ.get("SMTP_PORT", "587"))

        msg = MIMEMultipart()
        msg['From'] = from_email
        msg['To'] = to_email
        msg['Subject'] = "Skin Lesion Analysis Results"
        msg.attach(MIMEText(message, 'plain'))

        # The context manager closes the connection even if a step fails;
        # the timeout keeps an unreachable server from blocking the request.
        with smtplib.SMTP(host, port, timeout=30) as server:
            server.starttls()
            server.login(from_email, password)
            server.sendmail(from_email, to_email, msg.as_string())

    def process_image(image, language, email=None):
        """Gradio callback: analyse ``image`` and return the report text.

        Translation and e-mail are best-effort. If either fails, the error
        is logged, a short note is appended and the report is still
        returned.
        """
        result = system.analyze_image(image)
        if 'error' in result:
            return f"Error: {result['error']}"

        report = _format_report(result)

        # ``language`` is None when nothing is selected in the dropdown;
        # None, 'en' and unknown codes all mean "leave the report as is".
        if language in translation_models:
            try:
                report = translate(report, language)
            except Exception as e:
                logger.exception("Translation to %s failed: %s", language, e)
                report += "\n\n[Translation unavailable; results are shown in English.]"

        recipient = email.strip() if email else ""
        if recipient:
            try:
                send_email(recipient, report)
            except Exception as e:
                # The address itself is deliberately not logged.
                logger.exception("Failed to send results email: %s", e)
                report += "\n\n[The results could not be sent by email.]"
        return report

    # Only offer example rows whose image file is actually present, so the
    # UI is never pointed at a missing sample file.
    example_rows = [
        ["example_melanoma.jpg", "en", ""],
        ["example_nevus.jpg", "hi", ""],
        ["example_bcc.jpg", "ta", ""]
    ]
    available_examples = [row for row in example_rows if Path(row[0]).is_file()]

    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Upload Skin Image"),
            gr.Dropdown(
                choices=["en", "hi", "ta", "te", "bn", "mr", "pa", "gu", "kn", "ml"],
                value="en",
                label="Select Language",
            ),
            gr.Textbox(label="Email (optional)", placeholder="Enter your email to receive results")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Advanced Skin Lesion Analysis System",
        description=(
            "\n"
            "This system analyzes skin lesions using advanced computer vision and deep learning techniques.\n"
            "Key Features:\n"
            "- Lesion classification based on the HAM10000 dataset\n"
            "- Advanced image quality validation\n"
            "- Detailed analysis of lesion characteristics\n"
            "- Medical context and risk assessment\n"
            "- Option to receive results via email\n"
            "Important: This tool is for educational purposes only and should not replace professional medical diagnosis.\n"
        ),
        examples=available_examples or None,
        analytics_enabled=False,
    )
    return iface
# Build the interface at import time and start serving it immediately.
iface = create_gradio_interface()

# NOTE(review): share=True asks Gradio for a public share link; confirm
# that exposing this tool outside the local machine is intended.
_launch_options = {
    "server_name": "",
    "server_port": 7860,
    "share": True,
}
iface.launch(**_launch_options)