Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
from model import Generator | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
# Build the generator and restore its trained weights onto the CPU.
# Generator comes from the local `model` module; its constructor argument (3)
# is presumably the number of image channels — TODO confirm against model.py.
model = Generator(3)
model_path = 'state_dict.pth'
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Put the network into inference mode. torch.inference_mode() only disables
# autograd tracking; it does not switch layers such as Dropout or BatchNorm out
# of their training-time behaviour — only eval() does that.
model.eval()

# Per-request preprocessing: resize to 256x256, map pixel values from
# [0, 255] to [-1, 1] ((x - 0.5 * 255) / (0.5 * 255)), then convert the HWC
# numpy image into a CHW torch tensor.
transform = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])
def main(image):
    """Translate one image with the generator and return a displayable result.

    Args:
        image: The image delivered by the Gradio ``gr.Image`` input component
            (a numpy array under gr.Image's default ``type="numpy"``), or None
            when the user submits without providing an image.

    Returns:
        An HxWxC ``uint8`` numpy array with values in [0, 255], or None when
        no image was provided.
    """
    if image is None:
        # Nothing submitted: clear the output rather than failing in transform.
        return None

    tensor_img = transform(image=image)['image']
    with torch.inference_mode():
        # Add a batch dimension of 1 for the model, drop it again, and move
        # from CHW tensor layout to the HWC layout image viewers expect.
        pred = model(tensor_img.unsqueeze(0)).squeeze(0).permute(1, 2, 0).numpy()

    # The input was normalised to [-1, 1] (mean = std = 0.5); the generator is
    # assumed to produce output in that same range — TODO confirm against
    # model.py. Undo the normalisation and convert to uint8 explicitly: Gradio's
    # float -> uint8 conversion clips negative values to 0 (blacking out the
    # darker half of the image) and rejects floats outside [-1, 1].
    pred = np.clip(pred * 0.5 + 0.5, 0.0, 1.0)
    return (pred * 255.0).round().astype(np.uint8)
# Wire `main` into a minimal image-to-image web UI: one image input, one
# image output, and four sample files (relative paths) offered as examples.
app = gr.Interface(
    main,  # fn
    inputs=gr.Image(),
    outputs=gr.Image(),
    examples=['1.jpg', '2.jpg', '3.jpg', '4.jpg'],
)

# Start the Gradio server for this interface.
app.launch()