pix2pix-map / app.py
bhadresh-savani's picture
fixed the model name
b550f13
raw history blame
No virus
732 Bytes
import gradio as gr
import torch
from huggan.pytorch.pix2pix.modeling_pix2pix import GeneratorUNet
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.utils import save_image
# Preprocessing applied to every input image before it reaches the generator:
# resize to 256x256 with bicubic resampling, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
_preprocessing_steps = [
    Resize((256, 256), Image.BICUBIC),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = Compose(_preprocessing_steps)

# The generator is created once at module import, so every call to
# predict_fn reuses the same model instance.
model = GeneratorUNet.from_pretrained('huggan/pix2pix-maps')
def predict_fn(img):
    """Run the pix2pix generator on one image and save the result to disk.

    Args:
        img: Input image. The Gradio interface in this file is configured
            with ``type='pil'``, so this is expected to be a PIL image.

    Returns:
        The string ``'out.png'``: the path the generated image was written to.
    """
    # The model is called on a batch, so add a leading batch dimension.
    batch = transform(img).unsqueeze(0)

    # Inference only: disable autograd so no gradient graph is built or kept
    # alive for a result that is only ever written to an image file.
    with torch.no_grad():
        out = model(batch)

    # NOTE(review): every call writes to the same path, so overlapping
    # requests could overwrite each other's output -- confirm whether the
    # app is ever served with concurrent workers.
    save_image(out, 'out.png', normalize=True)
    return 'out.png'
# Sample images offered as one-click examples in the UI; each inner list is
# the argument list for a single call to predict_fn.
example_inputs = [['sample.jpg'], ['sample2.jpg'], ['sample3.jpg']]

# Build the web UI: one PIL image in, one image out, then start serving.
demo = gr.Interface(
    predict_fn,
    inputs=gr.inputs.Image(type='pil'),
    outputs='image',
    examples=example_inputs,
)
demo.launch()