| | """ |
| | Visualize predictions from trained model on local images folder |
| | """ |
| | import torch |
| | from torch.utils.data import Dataset |
| | import torchvision.transforms as transforms |
| | from split_model import SplitModel |
| | from PIL import Image, ImageDraw |
| | import numpy as np |
| | import glob |
| | import os |
| |
|
class LocalImageDataset(Dataset):
    """Dataset over all .png/.jpg/.jpeg files in a local folder.

    Each item is a tuple ``(transformed_tensor, pil_image, path)`` where the
    tensor is the 960x960 normalized model input and ``pil_image`` is the
    untouched original (kept for visualization at native resolution).
    """

    # Glob patterns for the accepted image file extensions.
    _PATTERNS = ("*.png", "*.jpg", "*.jpeg")

    def __init__(self, image_folder):
        found = []
        for pattern in self._PATTERNS:
            found.extend(glob.glob(os.path.join(image_folder, pattern)))
        # Sort for a deterministic, reproducible ordering across runs.
        self.image_paths = sorted(found)

        if not self.image_paths:
            raise ValueError(f"No images found in {image_folder}")

        # ImageNet normalization constants — must match what the model was trained with.
        self.transform = transforms.Compose([
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"Found {len(self.image_paths)} images in {image_folder}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        original = Image.open(path).convert('RGB')
        return self.transform(original), original, path
| |
|
def get_middle_of_groups(binary_array):
    """Collapse each run of consecutive 1's to a single 1 at its middle index.

    For a run spanning indices [start, end], the surviving index is
    ``(start + end) // 2`` (rounds toward the start for even-length runs).

    Example: [0,0,1,1,1,1,1,0,0,1,1,0] -> [0,0,0,0,1,0,0,0,0,1,0,0]
    """
    arr = np.asarray(binary_array)
    result = np.zeros_like(arr)

    # Pad with zeros and diff so that +1 marks a run start and -1 marks
    # one past a run end; this finds all groups without an explicit scan loop.
    padded = np.concatenate(([0], (arr == 1).astype(np.int64), [0]))
    edges = np.diff(padded)
    starts = np.nonzero(edges == 1)[0]
    ends = np.nonzero(edges == -1)[0] - 1

    for start, end in zip(starts, ends):
        result[(start + end) // 2] = 1

    return result
| |
|
# Model input resolution; the predicted split masks are indexed over this grid.
_GRID_SIZE = 960


def _load_fonts():
    """Return (title, text, label) TrueType fonts, or (None, None, None).

    Falls back to Nones when the DejaVu fonts are unavailable —
    ``ImageDraw.text`` accepts ``font=None`` and uses PIL's built-in default.
    """
    try:
        from PIL import ImageFont
        title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 32)
        text = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 24)
        label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28)
        return title, text, label
    except (OSError, ImportError):
        # Was a bare `except:` — that also swallowed KeyboardInterrupt/SystemExit.
        # truetype() raises OSError when the font file is missing.
        return None, None, None


def _draw_split_lines(image, h_mask, v_mask, line_width):
    """Return a copy of ``image`` with predicted split lines drawn on it.

    ``h_mask`` / ``v_mask`` are binary arrays of length ``_GRID_SIZE``.
    Grid positions equal to 1 are scaled to image pixel coordinates and drawn
    as horizontal red / vertical blue lines of the given width.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    W, H = image.size

    # Horizontal (row) splits in red.
    for y in range(_GRID_SIZE):
        if h_mask[y] == 1:
            y_scaled = int(y * H / _GRID_SIZE)
            draw.line([(0, y_scaled), (W, y_scaled)], fill='#ff0000', width=line_width)

    # Vertical (column) splits in blue.
    for x in range(_GRID_SIZE):
        if v_mask[x] == 1:
            x_scaled = int(x * W / _GRID_SIZE)
            draw.line([(x_scaled, 0), (x_scaled, H)], fill='#0000ff', width=line_width)

    return annotated


def visualize_prediction(model, dataset, idx, device, output_folder, threshold=0.5):
    """Visualize prediction for a single image (no ground truth).

    Runs the model on ``dataset[idx]``, thresholds the horizontal/vertical
    split predictions, and saves a three-panel composite (original image, raw
    predictions, cleaned predictions) to ``output_folder``.

    Args:
        model: trained SplitModel producing (h_pred, v_pred) logits/probabilities.
        dataset: LocalImageDataset yielding (tensor, PIL image, path).
        idx: index of the sample to visualize.
        device: torch device the model lives on.
        output_folder: destination folder for the composite PNG.
        threshold: probability cutoff for a grid position to count as a split.

    Returns:
        (pred_rows, pred_cols) — predicted row/column counts (splits + 1).
    """
    model.eval()

    image_tensor, original_image, image_path = dataset[idx]

    with torch.no_grad():
        image_batch = image_tensor.unsqueeze(0).to(device)
        h_pred, v_pred = model(image_batch)

    # Upsample predictions to the 960 grid by repeating each value twice
    # (assumes the model predicts at half resolution — TODO confirm).
    h_pred = h_pred.repeat_interleave(2, dim=1)
    v_pred = v_pred.repeat_interleave(2, dim=1)

    h_pred = h_pred.squeeze(0).cpu()
    v_pred = v_pred.squeeze(0).cpu()

    # Threshold to binary masks over the 960-position grid.
    h_binary = (h_pred > threshold).float().numpy()
    v_binary = (v_pred > threshold).float().numpy()

    # Collapse each contiguous run of positives to its single middle position.
    h_binary_clean = get_middle_of_groups(h_binary)
    v_binary_clean = get_middle_of_groups(v_binary)

    # N splits partition the table into N + 1 rows/columns.
    h_splits = h_binary_clean.sum()
    v_splits = v_binary_clean.sum()
    pred_rows = int(h_splits) + 1
    pred_cols = int(v_splits) + 1

    print(f"\nImage {idx}: {os.path.basename(image_path)}")
    print(f" H Splits: {h_splits:.0f} | Pred Rows: {pred_rows}")
    print(f" V Splits: {v_splits:.0f} | Pred Cols: {pred_cols}")
    print(f" Total Cells: {pred_rows * pred_cols}")
    print(f" H pred range: [{h_pred.min():.3f}, {h_pred.max():.3f}]")
    print(f" V pred range: [{v_pred.min():.3f}, {v_pred.max():.3f}]")

    # --- Build the composite visualization ---
    W, H = original_image.size

    # Upscale panels for readability.
    zoom_factor = 1.5
    W_zoomed = int(W * zoom_factor)
    H_zoomed = int(H * zoom_factor)

    original_zoomed = original_image.resize((W_zoomed, H_zoomed), Image.LANCZOS)

    # Layout constants (pixels).
    info_height = 150
    label_height = 60
    padding = 40

    font_title, font_text, font_label = _load_fonts()

    # Canvas: info box + three stacked labeled panels.
    total_width = W_zoomed + padding * 2
    total_height = info_height + (label_height + H_zoomed) * 3 + padding * 5
    vis_image = Image.new('RGB', (total_width, total_height), 'white')
    draw = ImageDraw.Draw(vis_image)

    # Info box with the predicted grid summary.
    info_x = padding
    info_y = padding

    draw.rectangle([info_x, info_y, total_width - padding, info_y + info_height],
                   fill='#e3f2fd', outline='#1565c0', width=4)

    draw.text((info_x + 20, info_y + 20), "PREDICTED SPLITS", fill='#0d47a1', font=font_title)
    draw.text((info_x + 20, info_y + 70), f"Rows: {pred_rows}", fill='black', font=font_text)
    draw.text((info_x + 20, info_y + 105), f"Columns: {pred_cols}", fill='black', font=font_text)
    draw.text((info_x + 300, info_y + 70), f"H Splits: {int(h_splits)}", fill='#c62828', font=font_text)
    draw.text((info_x + 300, info_y + 105), f"V Splits: {int(v_splits)}", fill='#1565c0', font=font_text)

    # Panel 1: original image.
    y_pos = info_height + padding * 2
    draw.rectangle([padding, y_pos, total_width - padding, y_pos + label_height],
                   fill='#f5f5f5', outline='#666666', width=2)
    draw.text((padding + 20, y_pos + 15), "Original Image", fill='#333333', font=font_label)

    vis_image.paste(original_zoomed, (padding, y_pos + label_height))

    # Panel 2: raw thresholded predictions (every positive grid position).
    y_pos = info_height + padding * 3 + label_height + H_zoomed
    draw.rectangle([padding, y_pos, total_width - padding, y_pos + label_height],
                   fill='#fff3e0', outline='#ff9800', width=2)
    draw.text((padding + 20, y_pos + 15), "Raw Model Predictions (All 1's)", fill='#e65100', font=font_label)

    raw_pred_image = _draw_split_lines(original_image, h_binary, v_binary, line_width=2)
    raw_pred_zoomed = raw_pred_image.resize((W_zoomed, H_zoomed), Image.LANCZOS)
    vis_image.paste(raw_pred_zoomed, (padding, y_pos + label_height))

    # Panel 3: cleaned predictions (one line per run).
    y_pos = info_height + padding * 4 + (label_height + H_zoomed) * 2
    draw.rectangle([padding, y_pos, total_width - padding, y_pos + label_height],
                   fill='#e3f2fd', outline='#1565c0', width=2)
    draw.text((padding + 20, y_pos + 15), "Cleaned Predictions (Middle Only)", fill='#0d47a1', font=font_label)

    pred_image = _draw_split_lines(original_image, h_binary_clean, v_binary_clean, line_width=3)
    pred_zoomed = pred_image.resize((W_zoomed, H_zoomed), Image.LANCZOS)
    vis_image.paste(pred_zoomed, (padding, y_pos + label_height))

    # Save the composite next to the other outputs.
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_folder, f'prediction_{base_name}.png')
    vis_image.save(output_path)
    print(f" Saved: {output_path}")

    return pred_rows, pred_cols
| |
|
# Fallback checkpoint locations searched when --model-path is not given,
# in priority order.
_DEFAULT_CHECKPOINT_PATHS = [
    'best_split_model.pth',
    '/home/ng6309/datascience/santhosh/experiments/tablet/best_split_model.pth',
    '/home/ng6309/datascience/santhosh/experiments/tablet/runs/tablet_split_20251006_214335/best_split_model.pth',
]


def _find_checkpoint():
    """Return the first existing default checkpoint path, or None if none exist."""
    for path in _DEFAULT_CHECKPOINT_PATHS:
        if os.path.exists(path):
            return path
    return None


def main():
    """CLI entry point: load a trained SplitModel and visualize predictions
    for every image in a folder, saving composites to the output folder."""
    import argparse
    parser = argparse.ArgumentParser(description='Visualize table split predictions on local images')
    parser.add_argument('--image-folder', type=str, required=True,
                        help='Path to folder containing images')
    parser.add_argument('--output-folder', type=str, default='predictions_output',
                        help='Path to folder for saving predictions (default: predictions_output)')
    parser.add_argument('--model-path', type=str, default=None,
                        help='Path to trained model checkpoint (if not specified, will search common locations)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binary predictions (default: 0.5)')
    parser.add_argument('--num-images', type=int, default=-1,
                        help='Number of images to process (-1 for all)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_folder, exist_ok=True)
    print(f"Output folder: {args.output_folder}")

    print(f"\nLoading images from: {args.image_folder}")
    dataset = LocalImageDataset(args.image_folder)

    print("\nLoading model...")
    model = SplitModel().to(device)

    # Explicit --model-path wins; otherwise search the default locations.
    checkpoint_path = args.model_path if args.model_path else _find_checkpoint()

    if checkpoint_path is None or not os.path.exists(checkpoint_path):
        print("ERROR: No trained model found! Please specify --model-path or ensure model exists in default locations")
        return

    print(f"Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Training metadata is optional in older checkpoints — guard the lookups
    # instead of crashing on a KeyError.
    if 'epoch' in checkpoint:
        print(f"\nModel trained for {checkpoint['epoch']} epochs")
    if 'val_loss' in checkpoint:
        print(f"Val loss: {checkpoint['val_loss']:.4f}")
    if 'val_h_f1' in checkpoint:
        print(f"Val H F1: {checkpoint['val_h_f1']:.3f}")
        print(f"Val V F1: {checkpoint['val_v_f1']:.3f}")

    # -1 means "all images"; otherwise clamp to the dataset size.
    num_samples = len(dataset) if args.num_images == -1 else min(args.num_images, len(dataset))

    print("\n" + "="*60)
    print(f"Visualizing predictions for {num_samples} images")
    print("="*60)

    for idx in range(num_samples):
        # Return value (pred_rows, pred_cols) is not needed here.
        visualize_prediction(model, dataset, idx, device, args.output_folder, threshold=args.threshold)

    print("\n" + "="*60)
    print(f"Completed processing {num_samples} images")
    print(f"All predictions saved to: {args.output_folder}")
    print("="*60)


if __name__ == "__main__":
    main()
| |
|