# utils/improved_contrast_analyzer.py: improved contrast detection utilities.
import numpy as np
import cv2
from scipy import ndimage
from sklearn.cluster import DBSCAN
class ImprovedContrastAnalyzer:
"""
Advanced contrast analyzer focused on Alzheimer's-friendly environments
"""
def __init__(self, wcag_threshold=4.5):
self.wcag_threshold = wcag_threshold
# ADE20K class mappings for important objects
self.important_classes = {
'floor': [3, 4], # floor, wood floor
'wall': [0, 1], # wall, building
'ceiling': [5], # ceiling
'sofa': [10], # sofa
'chair': [19], # chair
'table': [15], # table
'door': [25], # door
'window': [8], # window
'stairs': [53], # stairs
'bed': [7], # bed
}
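        # NOTE: these indices assume one particular ADE20K label ordering and
        # should be verified against the label map of the segmentation model
        # in use (0- vs. 1-indexed label lists shift every ID by one). Also
        # note that analyze_improved_contrast skips segment id 0 as background,
        # which under this mapping would skip one of the 'wall' classes.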
# Priority relationships (high priority = more important for safety)
self.priority_relationships = {
('floor', 'sofa'): 'high',
('floor', 'chair'): 'high',
('floor', 'table'): 'high',
('wall', 'sofa'): 'medium',
('wall', 'chair'): 'medium',
('wall', 'door'): 'high',
('floor', 'stairs'): 'critical',
('floor', 'bed'): 'medium',
('wall', 'window'): 'low',
('ceiling', 'wall'): 'low',
}
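        # Lookup note: these keys are not alphabetically sorted (e.g.
        # ('floor', 'chair') sorts to ('chair', 'floor')), so pair lookups in
        # analyze_improved_contrast check both orderings.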
def get_object_category(self, class_id):
"""Map segmentation class to object category"""
for category, class_ids in self.important_classes.items():
if class_id in class_ids:
return category
return 'other'
    def calculate_wcag_contrast(self, color1, color2):
        """Calculate the WCAG contrast ratio between two RGB colors."""
        # Relative luminance is the sum of the gamma-linearized, weighted
        # channel contributions computed by relative_luminance_component.
        lum1 = sum(self.relative_luminance_component(color1))
        lum2 = sum(self.relative_luminance_component(color2))
        lighter = max(lum1, lum2)
        darker = min(lum1, lum2)
        return (lighter + 0.05) / (darker + 0.05)
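    # Worked example: black (0, 0, 0) vs. white (255, 255, 255) gives
    # (1.0 + 0.05) / (0.0 + 0.05) = 21:1, the maximum possible ratio; the
    # 4.5:1 default threshold matches the WCAG AA level for normal text.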
def relative_luminance_component(self, color):
"""Calculate relative luminance components"""
rgb = [c / 255.0 for c in color]
components = []
factors = [0.2126, 0.7152, 0.0722]
for i, c in enumerate(rgb):
if c <= 0.03928:
components.append((c / 12.92) * factors[i])
else:
components.append(((c + 0.055) / 1.055) ** 2.4 * factors[i])
return components
def calculate_perceptual_contrast(self, color1, color2):
"""Calculate perceptual contrast including hue and saturation differences"""
        # Convert to HSV for a more perceptual comparison. OpenCV stores
        # 8-bit hue in [0, 180), so hue is normalized by 180, not 255.
        scale = np.array([180.0, 255.0, 255.0])
        hsv1 = cv2.cvtColor(np.uint8([[color1]]), cv2.COLOR_RGB2HSV)[0][0] / scale
        hsv2 = cv2.cvtColor(np.uint8([[color2]]), cv2.COLOR_RGB2HSV)[0][0] / scale
        # Hue difference (circular: wrap around the color wheel)
        hue_diff = abs(hsv1[0] - hsv2[0])
        if hue_diff > 0.5:
            hue_diff = 1 - hue_diff
# Saturation difference
sat_diff = abs(hsv1[1] - hsv2[1])
# Value (brightness) difference
val_diff = abs(hsv1[2] - hsv2[2])
# Combined perceptual score (0-1, higher is more different)
perceptual_contrast = np.sqrt(hue_diff**2 + sat_diff**2 + val_diff**2) / np.sqrt(3)
return perceptual_contrast
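    # Example with this normalization: fully saturated red vs. green differ
    # only in hue (Δh = 1/3), scoring (1/3) / sqrt(3) ≈ 0.19, which the 0.3
    # threshold used in analyze_improved_contrast still flags as low contrast.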
def find_clean_boundaries(self, mask1, mask2, min_boundary_length=50):
"""Find clean boundaries between two segments"""
# Dilate both masks slightly
kernel = np.ones((3, 3), np.uint8)
dilated1 = cv2.dilate(mask1.astype(np.uint8), kernel, iterations=1)
dilated2 = cv2.dilate(mask2.astype(np.uint8), kernel, iterations=1)
# Find intersection (boundary area)
boundary = (dilated1 & dilated2).astype(bool)
# Remove small disconnected boundary pieces
labeled_boundary = ndimage.label(boundary)[0]
for region_id in range(1, labeled_boundary.max() + 1):
region_mask = labeled_boundary == region_id
if np.sum(region_mask) < min_boundary_length:
boundary[region_mask] = False
return boundary
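    # Adjacency test rationale: a single 3x3 dilation grows each mask by one
    # pixel, so segments that touch produce a thin overlap band; disconnected
    # bands smaller than min_boundary_length pixels are discarded as noise.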
def get_representative_colors(self, image, mask, n_samples=1000):
"""Get representative colors from a masked region using clustering"""
if not np.any(mask):
return np.array([0, 0, 0])
# Sample pixels from the mask
y_coords, x_coords = np.where(mask)
if len(y_coords) > n_samples:
indices = np.random.choice(len(y_coords), n_samples, replace=False)
y_coords = y_coords[indices]
x_coords = x_coords[indices]
colors = image[y_coords, x_coords]
# Use DBSCAN clustering to find dominant colors
if len(colors) > 10:
clustering = DBSCAN(eps=30, min_samples=5).fit(colors)
labels = clustering.labels_
# Get the largest cluster
unique_labels, counts = np.unique(labels[labels >= 0], return_counts=True)
if len(unique_labels) > 0:
dominant_label = unique_labels[np.argmax(counts)]
dominant_colors = colors[labels == dominant_label]
return np.mean(dominant_colors, axis=0).astype(int)
# Fallback to mean color
return np.mean(colors, axis=0).astype(int)
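    # DBSCAN with eps=30 groups sampled pixels that lie within 30 RGB units of
    # one another, so the returned color reflects the dominant surface color
    # rather than a mean skewed by shadows or specular highlights.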
def analyze_improved_contrast(self, image, segmentation):
"""
Perform improved contrast analysis focused on important relationships
"""
h, w = segmentation.shape
results = {
'critical_issues': [],
'high_priority_issues': [],
'medium_priority_issues': [],
'statistics': {},
'visualization': image.copy()
}
# Get unique segments and their categories
unique_segments = np.unique(segmentation)
segment_categories = {}
segment_colors = {}
for seg_id in unique_segments:
if seg_id == 0: # Skip background
continue
mask = segmentation == seg_id
category = self.get_object_category(seg_id)
segment_categories[seg_id] = category
segment_colors[seg_id] = self.get_representative_colors(image, mask)
# Analyze important relationships
total_issues = 0
critical_count = 0
high_count = 0
medium_count = 0
for i, seg_id1 in enumerate(unique_segments):
if seg_id1 == 0:
continue
category1 = segment_categories.get(seg_id1, 'other')
if category1 == 'other':
continue
for seg_id2 in unique_segments[i+1:]:
if seg_id2 == 0:
continue
category2 = segment_categories.get(seg_id2, 'other')
if category2 == 'other':
continue
                # Check whether this pair is an important relationship; keys in
                # priority_relationships are not alphabetically ordered, so try
                # both orderings instead of tuple(sorted(...)).
                priority = self.priority_relationships.get(
                    (category1, category2),
                    self.priority_relationships.get((category2, category1)))
if priority is None:
continue
# Check if segments are adjacent
mask1 = segmentation == seg_id1
mask2 = segmentation == seg_id2
boundary = self.find_clean_boundaries(mask1, mask2)
if not np.any(boundary):
continue
# Calculate contrasts
color1 = segment_colors[seg_id1]
color2 = segment_colors[seg_id2]
wcag_contrast = self.calculate_wcag_contrast(color1, color2)
perceptual_contrast = self.calculate_perceptual_contrast(color1, color2)
# Determine if there's an issue
wcag_issue = wcag_contrast < self.wcag_threshold
perceptual_issue = perceptual_contrast < 0.3 # Threshold for perceptual difference
if wcag_issue or perceptual_issue:
issue = {
'categories': (category1, category2),
'segment_ids': (seg_id1, seg_id2),
'wcag_contrast': wcag_contrast,
'perceptual_contrast': perceptual_contrast,
'boundary_pixels': np.sum(boundary),
'priority': priority
}
# Color-code the boundary based on priority
if priority == 'critical':
results['critical_issues'].append(issue)
results['visualization'][boundary] = [255, 0, 0] # Red
critical_count += 1
elif priority == 'high':
results['high_priority_issues'].append(issue)
results['visualization'][boundary] = [255, 128, 0] # Orange
high_count += 1
elif priority == 'medium':
results['medium_priority_issues'].append(issue)
results['visualization'][boundary] = [255, 255, 0] # Yellow
medium_count += 1
total_issues += 1
# Calculate statistics
results['statistics'] = {
'total_issues': total_issues,
'critical_issues': critical_count,
'high_priority_issues': high_count,
'medium_priority_issues': medium_count,
'segments_analyzed': len([cat for cat in segment_categories.values() if cat != 'other'])
}
return results
# Drop-in wrapper that exposes the analyzer through the original
# contrast-detection interface.
class PrioritizedContrastDetector:
"""Wrapper for the improved contrast analyzer"""
def __init__(self, threshold=4.5):
self.analyzer = ImprovedContrastAnalyzer(wcag_threshold=threshold)
def analyze(self, image, segmentation, threshold, highlight_color=(255, 0, 0)):
"""Analyze with improved logic"""
results = self.analyzer.analyze_improved_contrast(image, segmentation)
# Convert to format expected by original interface
contrast_image = results['visualization']
        # Compatibility mask: all boundary overlays (red, orange, yellow) use a
        # saturated red channel, so flag any pixel whose R value is exactly 255.
        problem_areas = contrast_image[:, :, 0] == 255
# Format statistics
stats = results['statistics'].copy()
stats['threshold'] = threshold
stats['problem_count'] = stats['total_issues']
# Add detailed breakdown
if results['critical_issues']:
stats['critical_details'] = [
f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1"
for issue in results['critical_issues']
]
if results['high_priority_issues']:
stats['high_priority_details'] = [
f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1"
for issue in results['high_priority_issues']
]
return contrast_image, problem_areas, stats
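# Minimal usage sketch, assuming an RGB uint8 image and an integer label map
# from a segmentation model; the synthetic arrays below are stand-ins, with
# class ids 3 ('floor') and 19 ('chair') taken from important_classes above.
if __name__ == "__main__":
    demo_image = np.zeros((64, 64, 3), dtype=np.uint8)
    demo_image[:32] = (140, 140, 140)   # light gray "chair" region
    demo_image[32:] = (120, 120, 120)   # similar gray "floor" region
    demo_segmentation = np.zeros((64, 64), dtype=int)
    demo_segmentation[:32] = 19         # chair
    demo_segmentation[32:] = 3          # floor
    detector = PrioritizedContrastDetector(threshold=4.5)
    overlay, problems, stats = detector.analyze(demo_image, demo_segmentation, threshold=4.5)
    # The near-identical grays fail both the WCAG and perceptual checks, so the
    # floor/chair boundary is painted orange and counted as a high-priority issue.
    print(stats)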