"""Train a simple convolutional denoising model on paired noisy/target image patches."""

import os
import time
import multiprocessing
from datetime import datetime

import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor


def get_optimal_threads():
    """Calculate the optimal number of worker threads based on the CPU core count."""
    return max(1, multiprocessing.cpu_count() - 1)
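

# DenoisingModel is a small encoder/decoder CNN: two 3x3 convolutions extract
# features, 2x2 max pooling halves the spatial resolution, a transposed
# convolution restores it, and the decoder maps the features back to 3 RGB channels.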
class DenoisingModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(2, 2)

        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        u1 = self.up1(p1)
        d1 = self.dec1(u1)
        return d1
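

# DenoiseDataset pairs each "_noisy" image in noisy_folder with the correspondingly
# named "_target" image in target_folder and serves them as non-overlapping square
# patches; the patches of all images are exposed as one flat index range.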
class DenoiseDataset(Dataset):
    def __init__(self, noisy_folder, target_folder, patch_size=256):
        self.noisy_folder = noisy_folder
        self.target_folder = target_folder
        self.patch_size = patch_size
        self.image_pairs = [
            (os.path.join(noisy_folder, f), os.path.join(target_folder, f.replace("_noisy", "_target")))
            for f in os.listdir(noisy_folder) if "_noisy" in f
        ]
        self.transform = ToTensor()

        print("Dataset initialization:")
        print(f"- Noisy folder: {noisy_folder}")
        print(f"- Target folder: {target_folder}")
        print(f"- Patch size: {patch_size}")
        print(f"- Found {len(self.image_pairs)} image pairs")

        if not self.image_pairs:
            raise ValueError("No image pairs found. Check that the noisy and target images are named correctly.")

        # Count the patches each image contributes, dropping pairs whose noisy image
        # cannot be read rather than mutating the list while iterating over it.
        self.patches_per_image = {}
        valid_pairs = []
        for noisy_path, target_path in self.image_pairs:
            try:
                self.patches_per_image[noisy_path] = self._get_num_patches_per_image(noisy_path)
                valid_pairs.append((noisy_path, target_path))
            except Exception as e:
                print(f"Error calculating patches for {noisy_path}: {e}. Skipping this image pair.")
        self.image_pairs = valid_pairs

        self.total_patches = sum(self.patches_per_image.values())

    def __len__(self):
        return self.total_patches

    def __getitem__(self, idx):
        # Map the flat dataset index onto (image, patch-within-image).
        image_idx = 0
        cumulative_patches = 0

        for i, (noisy_path, _) in enumerate(self.image_pairs):
            num_patches = self.patches_per_image[noisy_path]
            if cumulative_patches + num_patches > idx:
                image_idx = i
                break
            cumulative_patches += num_patches

        patch_idx = idx - cumulative_patches
        noisy_path, target_path = self.image_pairs[image_idx]

        try:
            noisy_image = self._load_image(noisy_path)
            target_image = self._load_image(target_path)
        except Exception as e:
            print(f"Error loading image pair ({noisy_path}, {target_path}): {e}. Returning default values.")
            return torch.zeros((3, self.patch_size, self.patch_size)), torch.zeros((3, self.patch_size, self.patch_size))

        try:
            noisy_patch = self._get_patch(noisy_image, patch_idx)
            target_patch = self._get_patch(target_image, patch_idx)
        except Exception as e:
            print(f"Error getting patch from image pair ({noisy_path}, {target_path}): {e}. Returning default values.")
            return torch.zeros((3, self.patch_size, self.patch_size)), torch.zeros((3, self.patch_size, self.patch_size))

        return noisy_patch, target_patch

    def _load_image(self, image_path):
        try:
            image = Image.open(image_path).convert("RGB")
            return self.transform(image)
        except Exception as e:
            raise Exception(f"Error loading image {image_path}: {e}")

    def _get_num_patches_per_image(self, image_path):
        try:
            image = Image.open(image_path)
            width, height = image.size
            num_patches = (width // self.patch_size) * (height // self.patch_size)
            return num_patches
        except Exception as e:
            raise Exception(f"Error calculating patches for {image_path}: {e}")

    def _get_patch(self, image, patch_idx):
        # `image` is a (C, H, W) tensor produced by ToTensor().
        width, height = image.shape[2], image.shape[1]
        patches_per_row = width // self.patch_size
        row = patch_idx // patches_per_row
        col = patch_idx % patches_per_row

        x_start = col * self.patch_size
        y_start = row * self.patch_size
        return image[:, y_start:y_start + self.patch_size, x_start:x_start + self.patch_size]
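

# train_model trains the denoiser with Adam on an MSE objective, logs progress every
# 100 batches, writes a checkpoint every `save_interval` batches and at the end of
# each epoch, and saves the final weights into a timestamped checkpoint directory.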
def train_model(noisy_dir, target_dir, epochs, batch_size, learning_rate, save_interval, num_workers):
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda")
        print(f"\nUsing GPU: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
    else:
        device = torch.device("cpu")
        print("\nNo GPU detected, using CPU")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"model_checkpoints_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    print("\nTraining Configuration:")
    print(f"- Number of epochs: {epochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Learning rate: {learning_rate}")
    print(f"- Number of worker threads: {num_workers}")
    print(f"- Model checkpoint directory: {output_dir}")

    dataset = DenoiseDataset(noisy_dir, target_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )

    model = DenoisingModel().to(device)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
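        # Note: nn.DataParallel prefixes the keys of model.state_dict() with
        # "module.", so checkpoints saved below need that prefix handled when
        # they are reloaded into an unwrapped model.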

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    total_batches = len(dataloader)
    start_time = time.time()

    print("\nStarting training...")
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_idx, (noisy_patches, target_patches) in enumerate(dataloader):
            noisy_patches = noisy_patches.to(device, non_blocking=True)
            target_patches = target_patches.to(device, non_blocking=True)

            # Forward pass
            outputs = model(noisy_patches)
            loss = criterion(outputs, target_patches)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Progress report every 100 batches
            if (batch_idx + 1) % 100 == 0:
                elapsed_time = time.time() - start_time
                print(f"Epoch [{epoch+1}/{epochs}], "
                      f"Batch [{batch_idx+1}/{total_batches}], "
                      f"Loss: {loss.item():.6f}, "
                      f"Time: {elapsed_time:.2f}s")

            # Mid-epoch checkpoint every save_interval batches
            if (batch_idx + 1) % save_interval == 0:
                checkpoint_path = os.path.join(output_dir,
                                               f"denoising_model_epoch{epoch+1}_batch{batch_idx+1}.pth")
                torch.save({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item(),
                }, checkpoint_path)
                print(f"\nCheckpoint saved: {checkpoint_path}")

        avg_epoch_loss = epoch_loss / total_batches
        print(f"\nEpoch [{epoch+1}/{epochs}] completed. "
              f"Average loss: {avg_epoch_loss:.6f}")

        # End-of-epoch checkpoint
        checkpoint_path = os.path.join(output_dir, f"denoising_model_epoch{epoch+1}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_epoch_loss,
        }, checkpoint_path)
        print(f"Epoch checkpoint saved: {checkpoint_path}")

    print("\nTraining completed!")
    print(f"Total training time: {time.time() - start_time:.2f} seconds")

    final_model_path = os.path.join(output_dir, "denoising_model_final.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved: {final_model_path}")
def main():
    noisy_dir = "noisy_images"
    target_dir = "target_images"
    epochs = 10
    batch_size = 4
    learning_rate = 0.001
    save_interval = 1000
    num_workers = get_optimal_threads()

    train_model(noisy_dir, target_dir, epochs, batch_size, learning_rate, save_interval, num_workers)


if __name__ == "__main__":
    main()