| | import numpy as np
|
| | import cv2
|
| | from sklearn.cluster import KMeans
|
| | from sklearn.preprocessing import LabelEncoder
|
| | import matplotlib.pyplot as plt
|
| | from typing import Tuple, List, Dict, Optional
|
| | from dataclasses import dataclass
|
| | from enum import Enum
|
| | import seaborn as sns
|
| |
|
| | class ColorClassificationMethod(Enum):
|
| | """Different methods for classifying cell colors."""
|
| | DOMINANT_COLOR = "dominant_color"
|
| | AVERAGE_COLOR = "average_color"
|
| | HISTOGRAM_BINS = "histogram_bins"
|
| | HSV_QUANTIZATION = "hsv_quantization"
|
| |
|
| | @dataclass
|
| | class GridCell:
|
| | """Represents a single grid cell with its properties."""
|
| | row: int
|
| | col: int
|
| | average_color: np.ndarray
|
| | dominant_color: np.ndarray
|
| | brightness: float
|
| | saturation: float
|
| | hue: float
|
| | color_category: int
|
| | pixel_data: np.ndarray
|
| |
|
| | class ImageGridAnalyzer:
|
| | """
|
| | Analyzes images by dividing them into grids and classifying each cell's color properties.
|
| | Uses vectorized NumPy operations for high performance.
|
| | """
|
| |
|
| | def __init__(self, grid_size: Tuple[int, int] = (32, 32),
|
| | classification_method: ColorClassificationMethod = ColorClassificationMethod.DOMINANT_COLOR,
|
| | n_color_categories: int = 16):
|
| | """
|
| | Initialize the grid analyzer.
|
| |
|
| | Args:
|
| | grid_size: (rows, cols) for the grid division
|
| | classification_method: Method to classify cell colors
|
| | n_color_categories: Number of color categories for classification
|
| | """
|
| | self.grid_size = grid_size
|
| | self.classification_method = classification_method
|
| | self.n_color_categories = n_color_categories
|
| | self.color_classifier = None
|
| | self.category_colors = None
|
| |
|
| | def divide_image_into_grid(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
| | """
|
| | Divide image into a grid using vectorized operations.
|
| |
|
| | Args:
|
| | image: Input image (H, W, C)
|
| |
|
| | Returns:
|
| | Grid of cells (grid_rows, grid_cols, tile_height, tile_width, channels)
|
| | Tuple of (tile_height, tile_width)
|
| | """
|
| | h, w, c = image.shape
|
| | grid_rows, grid_cols = self.grid_size
|
| |
|
| |
|
| | tile_h = h // grid_rows
|
| | tile_w = w // grid_cols
|
| |
|
| |
|
| | adjusted_h = tile_h * grid_rows
|
| | adjusted_w = tile_w * grid_cols
|
| | image = image[:adjusted_h, :adjusted_w]
|
| |
|
| |
|
| |
|
| | grid = image.reshape(grid_rows, tile_h, grid_cols, tile_w, c)
|
| | grid = grid.transpose(0, 2, 1, 3, 4)
|
| |
|
| | return grid, (tile_h, tile_w)
|
| |
|
| | def analyze_grid_colors_vectorized(self, grid: np.ndarray) -> Dict[str, np.ndarray]:
|
| | """
|
| | Analyze color properties of all grid cells using vectorized operations.
|
| |
|
| | Args:
|
| | grid: Grid of cells (grid_rows, grid_cols, tile_h, tile_w, c)
|
| |
|
| | Returns:
|
| | Dictionary containing vectorized analysis results
|
| | """
|
| | grid_rows, grid_cols, tile_h, tile_w, c = grid.shape
|
| |
|
| |
|
| | cells_flat = grid.reshape(grid_rows * grid_cols, tile_h * tile_w, c)
|
| |
|
| |
|
| | average_colors = np.mean(cells_flat, axis=1)
|
| |
|
| |
|
| | dominant_colors = self._calculate_dominant_colors_vectorized(cells_flat)
|
| |
|
| |
|
| | hsv_averages = self._rgb_to_hsv_vectorized(average_colors)
|
| |
|
| |
|
| | brightness = hsv_averages[:, 2]
|
| |
|
| |
|
| | saturation = hsv_averages[:, 1]
|
| |
|
| |
|
| | hue = hsv_averages[:, 0]
|
| |
|
| |
|
| | results = {
|
| | 'average_colors': average_colors.reshape(grid_rows, grid_cols, c),
|
| | 'dominant_colors': dominant_colors.reshape(grid_rows, grid_cols, c),
|
| | 'brightness': brightness.reshape(grid_rows, grid_cols),
|
| | 'saturation': saturation.reshape(grid_rows, grid_cols),
|
| | 'hue': hue.reshape(grid_rows, grid_cols),
|
| | 'cells_data': grid
|
| | }
|
| |
|
| | return results
|
| |
|
| | def _calculate_dominant_colors_vectorized(self, cells_flat: np.ndarray) -> np.ndarray:
|
| | """
|
| | Calculate dominant color for each cell using vectorized operations.
|
| |
|
| | Args:
|
| | cells_flat: Flattened cells (total_cells, pixels_per_cell, channels)
|
| |
|
| | Returns:
|
| | Dominant colors for all cells (total_cells, channels)
|
| | """
|
| | import warnings
|
| |
|
| | total_cells, pixels_per_cell, c = cells_flat.shape
|
| | dominant_colors = np.zeros((total_cells, c))
|
| |
|
| |
|
| | batch_size = 100
|
| | for i in range(0, total_cells, batch_size):
|
| | end_idx = min(i + batch_size, total_cells)
|
| | batch = cells_flat[i:end_idx]
|
| |
|
| | for j, cell_pixels in enumerate(batch):
|
| |
|
| | unique_pixels = np.unique(cell_pixels, axis=0)
|
| |
|
| | if len(unique_pixels) >= 3 and pixels_per_cell > 100:
|
| |
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | warnings.filterwarnings("ignore", message=".*ConvergenceWarning.*")
|
| |
|
| | kmeans = KMeans(n_clusters=min(3, len(unique_pixels)),
|
| | random_state=42, n_init=5)
|
| | labels = kmeans.fit_predict(cell_pixels)
|
| |
|
| | unique_labels, counts = np.unique(labels, return_counts=True)
|
| | dominant_idx = unique_labels[np.argmax(counts)]
|
| | dominant_colors[i + j] = kmeans.cluster_centers_[dominant_idx]
|
| | elif len(unique_pixels) >= 2:
|
| |
|
| | unique_colors, counts = np.unique(cell_pixels, axis=0, return_counts=True)
|
| | dominant_colors[i + j] = unique_colors[np.argmax(counts)]
|
| | else:
|
| |
|
| | dominant_colors[i + j] = np.mean(cell_pixels, axis=0)
|
| |
|
| | return dominant_colors
|
| |
|
| | def _rgb_to_hsv_vectorized(self, rgb_colors: np.ndarray) -> np.ndarray:
|
| | """
|
| | Convert RGB colors to HSV using vectorized operations.
|
| |
|
| | Args:
|
| | rgb_colors: RGB colors (N, 3)
|
| |
|
| | Returns:
|
| | HSV colors (N, 3)
|
| | """
|
| |
|
| | rgb_normalized = rgb_colors / 255.0
|
| |
|
| |
|
| | dummy_img = rgb_normalized.reshape(-1, 1, 3).astype(np.float32)
|
| | hsv_img = cv2.cvtColor(dummy_img, cv2.COLOR_RGB2HSV)
|
| | hsv_colors = hsv_img.reshape(-1, 3)
|
| |
|
| | return hsv_colors
|
| |
|
| | def classify_colors(self, color_data: Dict[str, np.ndarray]) -> np.ndarray:
|
| | """
|
| | Classify each grid cell into color categories.
|
| |
|
| | Args:
|
| | color_data: Dictionary containing color analysis results
|
| |
|
| | Returns:
|
| | Color categories for each grid cell (grid_rows, grid_cols)
|
| | """
|
| | import warnings
|
| |
|
| | if self.classification_method == ColorClassificationMethod.AVERAGE_COLOR:
|
| | features = color_data['average_colors']
|
| | elif self.classification_method == ColorClassificationMethod.DOMINANT_COLOR:
|
| | features = color_data['dominant_colors']
|
| | elif self.classification_method == ColorClassificationMethod.HSV_QUANTIZATION:
|
| |
|
| | h = color_data['hue']
|
| | s = color_data['saturation']
|
| | v = color_data['brightness']
|
| | features = np.stack([h, s, v], axis=-1)
|
| | else:
|
| | features = color_data['average_colors']
|
| |
|
| |
|
| | grid_rows, grid_cols = features.shape[:2]
|
| | features_flat = features.reshape(-1, features.shape[-1])
|
| |
|
| |
|
| | unique_features = np.unique(features_flat, axis=0)
|
| | effective_clusters = min(self.n_color_categories, len(unique_features))
|
| |
|
| | if effective_clusters < 2:
|
| |
|
| | print(f"Warning: Only {len(unique_features)} unique colors found. Using simple classification.")
|
| | categories = np.zeros(len(features_flat), dtype=int)
|
| | categories_grid = categories.reshape(grid_rows, grid_cols)
|
| | self.category_colors = unique_features[:1] if len(unique_features) > 0 else np.array([[128, 128, 128]])
|
| | return categories_grid
|
| |
|
| |
|
| | with warnings.catch_warnings():
|
| | warnings.filterwarnings("ignore", category=UserWarning)
|
| | warnings.filterwarnings("ignore", message=".*ConvergenceWarning.*")
|
| |
|
| | self.color_classifier = KMeans(n_clusters=effective_clusters,
|
| | random_state=42, n_init=10)
|
| | categories = self.color_classifier.fit_predict(features_flat)
|
| |
|
| |
|
| | self.category_colors = self.color_classifier.cluster_centers_
|
| |
|
| |
|
| | categories_grid = categories.reshape(grid_rows, grid_cols)
|
| |
|
| | return categories_grid
|
| |
|
| | def apply_thresholding(self, color_data: Dict[str, np.ndarray],
|
| | brightness_threshold: float = 0.5,
|
| | saturation_threshold: float = 0.3) -> Dict[str, np.ndarray]:
|
| | """
|
| | Apply thresholding to create binary masks for different criteria.
|
| |
|
| | Args:
|
| | color_data: Color analysis results
|
| | brightness_threshold: Threshold for bright/dark classification
|
| | saturation_threshold: Threshold for saturated/desaturated classification
|
| |
|
| | Returns:
|
| | Dictionary containing various threshold masks
|
| | """
|
| | brightness = color_data['brightness']
|
| | saturation = color_data['saturation']
|
| |
|
| |
|
| | brightness_norm = brightness / 255.0 if brightness.max() > 1.0 else brightness
|
| | saturation_norm = saturation / 255.0 if saturation.max() > 1.0 else saturation
|
| |
|
| | thresholds = {
|
| | 'bright_mask': brightness_norm > brightness_threshold,
|
| | 'dark_mask': brightness_norm <= brightness_threshold,
|
| | 'saturated_mask': saturation_norm > saturation_threshold,
|
| | 'desaturated_mask': saturation_norm <= saturation_threshold,
|
| | 'bright_saturated': (brightness_norm > brightness_threshold) &
|
| | (saturation_norm > saturation_threshold),
|
| | 'dark_saturated': (brightness_norm <= brightness_threshold) &
|
| | (saturation_norm > saturation_threshold)
|
| | }
|
| |
|
| | return thresholds
|
| |
|
| | def analyze_image_complete(self, image: np.ndarray) -> Dict:
|
| | """
|
| | Complete analysis pipeline for an image.
|
| |
|
| | Args:
|
| | image: Input image (H, W, C)
|
| |
|
| | Returns:
|
| | Complete analysis results
|
| | """
|
| | print(f"Analyzing image with {self.grid_size[0]}x{self.grid_size[1]} grid...")
|
| |
|
| |
|
| | grid, tile_size = self.divide_image_into_grid(image)
|
| | print(f"Created grid with tile size: {tile_size}")
|
| |
|
| |
|
| | color_data = self.analyze_grid_colors_vectorized(grid)
|
| | print("Completed color analysis")
|
| |
|
| |
|
| | color_categories = self.classify_colors(color_data)
|
| | print(f"Classified into {self.n_color_categories} color categories")
|
| |
|
| |
|
| | thresholds = self.apply_thresholding(color_data)
|
| | print("Applied thresholding")
|
| |
|
| |
|
| | results = {
|
| | 'grid': grid,
|
| | 'tile_size': tile_size,
|
| | 'color_data': color_data,
|
| | 'color_categories': color_categories,
|
| | 'thresholds': thresholds,
|
| | 'category_colors': self.category_colors
|
| | }
|
| |
|
| | return results
|
| |
|
| | def visualize_analysis(self, results: Dict, original_image: np.ndarray):
|
| | """
|
| | Create comprehensive visualizations of the analysis results.
|
| |
|
| | Args:
|
| | results: Analysis results from analyze_image_complete
|
| | original_image: Original input image
|
| | """
|
| | fig, axes = plt.subplots(2, 4, figsize=(20, 10))
|
| |
|
| |
|
| | axes[0, 0].imshow(original_image)
|
| | axes[0, 0].set_title('Original Image')
|
| | axes[0, 0].axis('off')
|
| |
|
| |
|
| | avg_colors = results['color_data']['average_colors'].astype(np.uint8)
|
| | axes[0, 1].imshow(avg_colors)
|
| | axes[0, 1].set_title('Average Colors per Cell')
|
| | axes[0, 1].axis('off')
|
| |
|
| |
|
| | dom_colors = results['color_data']['dominant_colors'].astype(np.uint8)
|
| | axes[0, 2].imshow(dom_colors)
|
| | axes[0, 2].set_title('Dominant Colors per Cell')
|
| | axes[0, 2].axis('off')
|
| |
|
| |
|
| | categories = results['color_categories']
|
| | im_cat = axes[0, 3].imshow(categories, cmap='tab20')
|
| | axes[0, 3].set_title(f'Color Categories ({self.n_color_categories} classes)')
|
| | axes[0, 3].axis('off')
|
| | plt.colorbar(im_cat, ax=axes[0, 3])
|
| |
|
| |
|
| | brightness = results['color_data']['brightness']
|
| | im_bright = axes[1, 0].imshow(brightness, cmap='gray')
|
| | axes[1, 0].set_title('Brightness Values')
|
| | axes[1, 0].axis('off')
|
| | plt.colorbar(im_bright, ax=axes[1, 0])
|
| |
|
| |
|
| | saturation = results['color_data']['saturation']
|
| | im_sat = axes[1, 1].imshow(saturation, cmap='viridis')
|
| | axes[1, 1].set_title('Saturation Values')
|
| | axes[1, 1].axis('off')
|
| | plt.colorbar(im_sat, ax=axes[1, 1])
|
| |
|
| |
|
| | axes[1, 2].imshow(results['thresholds']['bright_mask'], cmap='gray')
|
| | axes[1, 2].set_title('Bright Areas (Threshold)')
|
| | axes[1, 2].axis('off')
|
| |
|
| |
|
| | axes[1, 3].imshow(results['thresholds']['saturated_mask'], cmap='gray')
|
| | axes[1, 3].set_title('Saturated Areas (Threshold)')
|
| | axes[1, 3].axis('off')
|
| |
|
| | plt.tight_layout()
|
| | plt.show()
|
| |
|
| |
|
| | self._visualize_color_palette(results['category_colors'])
|
| |
|
| | def _visualize_color_palette(self, category_colors: np.ndarray):
|
| | """
|
| | Visualize the color category palette.
|
| |
|
| | Args:
|
| | category_colors: Color palette (n_categories, channels)
|
| | """
|
| | if category_colors is None:
|
| | return
|
| |
|
| | fig, ax = plt.subplots(1, 1, figsize=(12, 2))
|
| |
|
| |
|
| | colors = category_colors.copy()
|
| | if colors.max() > 1.0:
|
| | colors = colors / 255.0
|
| |
|
| |
|
| | palette = colors.reshape(1, -1, 3)
|
| | ax.imshow(palette, aspect='auto')
|
| | ax.set_xlim(0, len(colors))
|
| | ax.set_ylim(0, 1)
|
| | ax.set_xticks(range(len(colors)))
|
| | ax.set_xticklabels([f'Cat {i}' for i in range(len(colors))])
|
| | ax.set_title(f'Color Category Palette ({len(colors)} categories)')
|
| | ax.set_ylabel('Color Categories')
|
| |
|
| | plt.tight_layout()
|
| | plt.show()
|
| |
|
| | def get_performance_stats(self, results: Dict) -> Dict:
|
| | """
|
| | Calculate performance and analysis statistics.
|
| |
|
| | Args:
|
| | results: Analysis results
|
| |
|
| | Returns:
|
| | Dictionary containing statistics
|
| | """
|
| | grid_shape = results['color_categories'].shape
|
| | total_cells = np.prod(grid_shape)
|
| |
|
| |
|
| | unique_categories = len(np.unique(results['color_categories']))
|
| |
|
| |
|
| | brightness = results['color_data']['brightness']
|
| |
|
| |
|
| | saturation = results['color_data']['saturation']
|
| |
|
| | stats = {
|
| | 'grid_size': f"{grid_shape[0]}x{grid_shape[1]}",
|
| | 'total_cells': total_cells,
|
| | 'unique_color_categories': unique_categories,
|
| | 'category_utilization': unique_categories / self.n_color_categories,
|
| | 'avg_brightness': np.mean(brightness),
|
| | 'brightness_std': np.std(brightness),
|
| | 'avg_saturation': np.mean(saturation),
|
| | 'saturation_std': np.std(saturation),
|
| | 'bright_cells_percent': np.mean(results['thresholds']['bright_mask']) * 100,
|
| | 'saturated_cells_percent': np.mean(results['thresholds']['saturated_mask']) * 100
|
| | }
|
| |
|
| | return stats
|
| |
|
| |
|
| | def main():
|
| | """
|
| | Example usage of the ImageGridAnalyzer.
|
| | """
|
| |
|
| | def create_test_image():
|
| |
|
| | img = np.zeros((256, 256, 3), dtype=np.uint8)
|
| |
|
| |
|
| | for i in range(256):
|
| | for j in range(256):
|
| | img[i, j, 0] = i
|
| | img[i, j, 1] = j
|
| | img[i, j, 2] = (i + j) % 255
|
| |
|
| | return img
|
| |
|
| |
|
| | analyzer = ImageGridAnalyzer(
|
| | grid_size=(32, 32),
|
| | classification_method=ColorClassificationMethod.DOMINANT_COLOR,
|
| | n_color_categories=16
|
| | )
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | test_image = cv2.imread('processed_quantized.jpg')
|
| | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
|
| |
|
| | print("Starting analysis...")
|
| |
|
| |
|
| | results = analyzer.analyze_image_complete(test_image)
|
| |
|
| |
|
| | stats = analyzer.get_performance_stats(results)
|
| |
|
| | print("\n=== Analysis Statistics ===")
|
| | for key, value in stats.items():
|
| | print(f"{key}: {value}")
|
| |
|
| |
|
| |
|
| |
|
| | return results, analyzer
|
| |
|
| | if __name__ == "__main__":
|
| | results, analyzer = main() |