File size: 6,645 Bytes
57db94b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import os
import torch
import cv2
import numpy as np
import shutil
from combined import IFNet, warp
import torchvision.transforms as transforms
import time
# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output directory for the io{N}.png / gt{N}.png evaluation pairs.
# Any existing contents are deleted so pairs from a previous run cannot be
# mixed with the new ones.
eval_dir = "eval_data"
if os.path.exists(eval_dir):
    shutil.rmtree(eval_dir)
os.makedirs(eval_dir)
print(f"Created directory: {eval_dir}")

# Build the model on the chosen device.
model = IFNet().to(device)

# Load trained weights if the checkpoint file is present.
checkpoint_path = "save_checkpoints/model_epoch_50.pth"
if os.path.exists(checkpoint_path):
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Both metadata fields are optional: a checkpoint without 'epoch' must not
    # abort the run after the weights have already been loaded successfully.
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')} with PSNR: {checkpoint.get('psnr', 'N/A')} dB")
else:
    print("No checkpoint found, using uninitialized model")
# Switch layers such as dropout / batch-norm (if IFNet has any) to inference behaviour.
model.eval()

# Per-frame preprocessing: ToTensor converts an HWC uint8 array to a CHW float
# tensor in [0, 1]; Normalize(0.5, 0.5) then maps it to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def preprocess_images(img0_path, img1_path, gt_path=None, size=(256, 256)):
    """Load two input frames (and optionally a ground-truth frame) for the model.

    Args:
        img0_path: Path to the first input frame.
        img1_path: Path to the second input frame.
        gt_path: Optional path to the ground-truth frame. A missing file yields
            ``gt=None``; an unreadable file yields ``gt=None`` plus a printed
            warning.
        size: ``(width, height)`` the two inputs are resized to before being fed
            to the model (the order of ``cv2.resize``'s ``dsize``). Defaults to
            the 256x256 resolution this script has always used.

    Returns:
        A tuple ``(input_tensor, original_size, orig_img0, orig_img1, gt)``:
        - ``input_tensor``: both transformed frames concatenated along the
          channel axis with a batch axis prepended, shape
          ``(1, 6, size[1], size[0])``, on ``device``.
        - ``original_size``: ``(height, width)`` of the first frame.
        - ``orig_img0`` / ``orig_img1``: the full-resolution RGB frames.
        - ``gt``: the RGB ground-truth array, or ``None``.

    Raises:
        ValueError: if either input frame cannot be read.
    """
    img0 = cv2.imread(img0_path)
    img1 = cv2.imread(img1_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img0 is None or img1 is None:
        raise ValueError(f"Could not read images: {img0_path}, {img1_path}")

    gt = None
    if gt_path and os.path.exists(gt_path):
        gt = cv2.imread(gt_path)
        if gt is None:
            print(f"Warning: Could not read ground truth image: {gt_path}")
        else:
            gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

    # OpenCV loads BGR; the rest of the pipeline works in RGB. cvtColor returns
    # new arrays, so the full-resolution frames below need no extra copy.
    img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)

    # (height, width) of the first frame, used later to resize the prediction back.
    original_size = (img0.shape[0], img0.shape[1])

    # Resize to the model's input resolution, then normalise to [-1, 1].
    img0_tensor = transform(cv2.resize(img0, size))
    img1_tensor = transform(cv2.resize(img1, size))

    # Concatenate the two CHW frames along the channel axis (dim 0) and add a
    # batch axis; .to(device) keeps the input on the same device as the model.
    input_tensor = torch.cat((img0_tensor, img1_tensor), 0).unsqueeze(0).to(device)
    return input_tensor, original_size, img0, img1, gt
def tensor_to_image(tensor):
    """Convert a normalised CHW image tensor back to an HWC ``uint8`` array.

    Inverts the ``Normalize(mean=0.5, std=0.5)`` preprocessing
    (``x * 0.5 + 0.5``), clamps to [0, 1], scales to [0, 255] and rounds.

    Args:
        tensor: A single image tensor of shape (C, H, W), nominally in [-1, 1].
            It may live on any device and may require grad.

    Returns:
        ``np.ndarray`` of shape (H, W, C) and dtype ``uint8``.
    """
    # detach(): .numpy() refuses tensors that require grad.
    # cpu(): numpy needs host memory.
    tensor = tensor.detach().cpu()
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    img = tensor.numpy().transpose(1, 2, 0) * 255
    # Round instead of truncating: astype(uint8) alone floors, which biases every
    # pixel down by ~0.5 on average and slightly depresses PSNR/SSIM.
    return np.round(img).astype(np.uint8)
import traceback


def _find_frame(directory, stem):
    """Return the path of the first existing file ``stem + ext`` in ``directory``, or None.

    Extensions are tried in order: .png, .jpg, .jpeg, then no extension.
    """
    for ext in (".png", ".jpg", ".jpeg", ""):
        candidate = os.path.join(directory, f"{stem}{ext}")
        if os.path.isfile(candidate):
            return candidate
    return None


# Counter for output filename indexing (io{N}.png / gt{N}.png); only advances
# after a pair has been written successfully.
counter = 1

# Process all subdirectories in the dataset
dataset_dir = "datasets/test_2k"
print(f"Looking for frames in: {dataset_dir}")

# Each immediate subdirectory should hold one triplet: frame1 and frame3 are fed
# to the model, and frame2 is used as the ground truth.
subdirs = []
try:
    # sorted(): os.listdir order is arbitrary; sorting makes the io{N}/gt{N}
    # numbering reproducible across runs and machines.
    subdirs = sorted(
        item for item in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, item))
    )
    print(f"Found {len(subdirs)} subdirectories in {dataset_dir}")
except OSError as e:
    print(f"Error listing subdirectories: {e}")

for subdir in subdirs:
    subdir_path = os.path.join(dataset_dir, subdir)
    frame1_path = _find_frame(subdir_path, "frame1")
    frame2_path = _find_frame(subdir_path, "frame2")
    frame3_path = _find_frame(subdir_path, "frame3")
    if not (frame1_path and frame2_path and frame3_path):
        print(f"Skipping {subdir_path} - missing required frames")
        continue
    print(f"Processing {subdir_path}")
    try:
        input_tensor, original_size, orig_img0, orig_img1, gt = preprocess_images(
            frame1_path, frame3_path, frame2_path
        )
        # Without a readable ground truth the pair would be incomplete; skip it
        # before anything is written so no orphan io{N}.png is left behind.
        if gt is None:
            print(f"Skipping {subdir_path} - ground truth could not be read")
            continue

        # CUDA kernels run asynchronously: synchronize around the forward pass,
        # otherwise the timer only measures the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            flow, mask, interpolated = model(input_tensor)
        if device.type == "cuda":
            torch.cuda.synchronize()
        inference_time = time.perf_counter() - start_time
        print(f"Inference time: {inference_time:.4f} seconds")

        interpolated_img = tensor_to_image(interpolated[0])
        if interpolated_img.shape[:2] != original_size:
            # cv2.resize expects (width, height).
            interpolated_img = cv2.resize(interpolated_img, (original_size[1], original_size[0]))

        output_io = os.path.join(eval_dir, f"io{counter}.png")
        output_gt = os.path.join(eval_dir, f"gt{counter}.png")
        # Images are RGB in memory; OpenCV writes BGR.
        ok_io = cv2.imwrite(output_io, cv2.cvtColor(interpolated_img, cv2.COLOR_RGB2BGR))
        ok_gt = cv2.imwrite(output_gt, cv2.cvtColor(gt, cv2.COLOR_RGB2BGR))
        # cv2.imwrite reports failure through its return value instead of raising.
        if not (ok_io and ok_gt):
            raise IOError(f"Failed to write {output_io} and/or {output_gt}")
        print(f"Saved pair {counter}: {output_io} and {output_gt}")
        counter += 1
    except Exception as e:
        # Best effort per directory: report and move on to the next triplet.
        print(f"Error processing {subdir_path}: {e}")
        traceback.print_exc()

print(f"Processing complete. Generated {counter-1} pairs of images in {eval_dir}")