| import torch |
| import gradio as gr |
| from PIL import Image |
| import torchvision.transforms.functional as TF |
| import torchvision.models as models |
| import torch.nn as nn |
|
|
| |
# Recreate the EfficientNet-B4 architecture and replace the last classifier
# layer with a 2-output linear head; predict() unpacks the two outputs as
# (wsr, hsr).
model = models.efficientnet_b4()
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_features, out_features=2)

# Load the trained weights with every tensor mapped onto the CPU, then switch
# the network to evaluation mode for inference.
model.load_state_dict(
    torch.load(
        "model/deep-image-squish-predictor-V0.pth",
        map_location="cpu",
    )
)
model.eval()
|
|
def predict(image):
    """Predict how an image was squished and undo the aspect-ratio distortion.

    The image is scaled so its longer side is 256 px, pasted at the top-left
    corner of a 256x256 RGB canvas, and fed to ``model``. The model yields
    two ratios ``(wsr, hsr)``. The smaller one rescales the *opposite* side
    of the original image: ``wsr`` scales the height, otherwise ``hsr``
    scales the width.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(text, image)`` tuple: a string reporting both predicted ratios
        and the aspect-corrected PIL image.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image.")

    size = 256  # side length of the square model input

    width, height = image.size
    ratio = width / height
    # max(1, ...) keeps very wide/tall images from collapsing to a 0-pixel
    # side, which resize() rejects.
    if width > height:
        new_width, new_height = size, max(1, int(size / ratio))
    else:
        new_width, new_height = max(1, int(size * ratio)), size

    resized_image = TF.resize(image, (new_height, new_width))

    # Paste at (0, 0) on a size x size RGB canvas; the uncovered area keeps
    # the canvas's default fill.
    padded_image = Image.new("RGB", (size, size))
    padded_image.paste(resized_image, (0, 0))
    # unsqueeze(0) adds the batch dimension expected by the model.
    image_tensor = TF.to_tensor(padded_image).unsqueeze(0)

    with torch.no_grad():
        output = model(image_tensor)

    wsr, hsr = output.squeeze().tolist()
    # The linear head is unbounded, so a tiny or negative ratio could give a
    # non-positive size; clamp to at least 1 pixel.
    if wsr < hsr:
        height = max(1, int(height * wsr))
    else:
        width = max(1, int(width * hsr))
    reconstructed_image = TF.resize(image, (height, width))
    return f"Squish Ratio: (Width, Height)= ({wsr:.2f}, {hsr:.2f})", reconstructed_image
|
|
| |
# Sample inputs handed to gr.Interface(examples=...); each inner list holds
# the value for the interface's single image input.
examples = [
    [example_path]
    for example_path in (
        "example_images/image1.jpg",
        "example_images/image2.jpg",
        "example_images/image3.jpg",
        "example_images/image4.jpg",
    )
]
|
|
| |
# Wire predict() into a Gradio UI: one PIL image in; a text summary of the
# predicted ratios plus the aspect-corrected image out.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    [
        gr.Textbox(label="Prediction"),
        gr.Image(type="pil", label="Reconstructed Aspect-Ratio"),
    ],
    examples=examples,
    title="Deep Image Squish Predictor",
    description="Upload an image to see the predicted squish ratios.",
)

# Start the Gradio app.
iface.launch()
|
|