"""Reconstruct one image with the autoencoder and display it beside the input.

Usage:
    python <script> <image_path>

The script loads weights from ``autoencoder_celeba.pth`` in the current
working directory, resizes the given image to 128x128, runs it through the
model, and opens a matplotlib window showing the original and the
reconstruction side by side.
"""

import sys

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from model import Autoencoder

# Validate the command line before touching the model or checkpoint, so a
# missing argument yields a usage message instead of a late IndexError.
# Extra arguments are ignored, as before.
if len(sys.argv) < 2:
    sys.exit("Usage: python <script> <image_path>")
image_path = sys.argv[1]

# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Autoencoder().to(device)

# The checkpoint is only ever used as a state dict (see ``.items()`` below),
# so restrict unpickling to tensors and plain containers with
# ``weights_only=True`` rather than allowing arbitrary pickled objects.
state_dict = torch.load(
    "autoencoder_celeba.pth",
    map_location=device,
    weights_only=True,
)

# Strip the "_orig_mod." marker from parameter names so they match the keys
# of a plain ``Autoencoder``.
# NOTE(review): presumably the checkpoint was saved from a
# ``torch.compile()``-wrapped model, which adds this prefix -- confirm.
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

model.load_state_dict(new_state_dict)
model.eval()

# Resize to a fixed 128x128 input and convert to a float tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# ``convert`` returns a new, fully loaded image, so the file handle can be
# released as soon as the ``with`` block exits.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

# Add a leading batch axis of size 1 before moving the tensor to the device.
input_tensor = transform(image).unsqueeze(0).to(device)

# Inference only: no gradients are needed.
with torch.no_grad():
    output = model(input_tensor)

# ``squeeze(0)`` drops only the batch axis; a bare ``squeeze()`` would also
# drop any other size-1 axis.
input_np = input_tensor.squeeze(0).cpu().numpy()
# The model's output range is not visible from this script; clamp so that
# imshow receives valid [0, 1] float RGB values instead of clipping them
# itself with a warning.
output_np = output.squeeze(0).clamp(0.0, 1.0).cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

for ax, picture, title in zip(
    axes,
    (input_np, output_np),
    ("Original", "Reconstructed"),
):
    # Move the channel axis last, which is the layout imshow expects.
    ax.imshow(picture.transpose(1, 2, 0))
    ax.set_title(title)
    ax.axis("off")

plt.show()