Spaces:
Runtime error
Runtime error
import os

import gradio as gr
import torch
import torchvision
from PIL import Image
# Model names offered in the UI selector. No model is actually loaded yet:
# predict() currently returns a random placeholder regardless of the choice.
MODELS_TYPE = "ModelA ModelB ModelC".split()
def predict(input_image, model_name):
    """Run the (placeholder) DTM model on one image and return the result.

    Args:
        input_image: Image supplied by the Gradio input component. Normally an
            array-like with ``.astype`` (presumably an ``(H, W)``, ``(H, W, 3)``
            or ``(H, W, 4)`` uint8 array -- confirm against the Gradio version
            in use). A ``PIL.Image.Image`` is also accepted. ``None`` (no image
            uploaded) is allowed.
        model_name: Name of the selected model, one of ``MODELS_TYPE``.
            Currently unused: no real model is loaded yet.

    Returns:
        A ``PIL.Image.Image`` holding the prediction, or ``None`` when no input
        image was given (Gradio then leaves the output empty).
    """
    # Gradio passes None when the user submits without uploading an image;
    # calling ``.astype`` on it would raise AttributeError.
    if input_image is None:
        return None

    if isinstance(input_image, Image.Image):
        pil_image = input_image.convert("RGB")
    else:
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA) and
        # then normalise to RGB. Forcing mode="RGB" in fromarray can garble or
        # reject non-3-channel arrays, and that argument is deprecated in
        # recent Pillow releases.
        # NOTE(review): astype("uint8") truncates float arrays in [0, 1];
        # assumes the input is already in the 0-255 range.
        pil_image = Image.fromarray(input_image.astype("uint8")).convert("RGB")

    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    torch_image = torchvision.transforms.ToTensor()(pil_image)

    # TODO: replace this placeholder with the real model chosen by model_name.
    prediction = torch.rand(torch_image.shape)

    # Tensor -> PIL image for display in the Gradio output component.
    return torchvision.transforms.ToPILImage()(prediction)
# Build and launch the web UI.
# ``gr.inputs.*`` and the ``shape=`` argument of ``gr.Image`` were removed in
# Gradio 4. The top-level components used here exist in both Gradio 3.x and
# 4.x. Without ``shape=`` the input is no longer cropped/resized to 512x512 by
# the UI; predict() works on any image size.
_EXAMPLE_IMAGE = "demo_imgs/fake.jpg"  # TODO: use a real image

# Only offer the example when its file is present. A missing example file may
# make Gradio fail while preparing the examples at startup.
_EXAMPLES = (
    [[_EXAMPLE_IMAGE, MODELS_TYPE[0]]] if os.path.exists(_EXAMPLE_IMAGE) else None
)

iface = gr.Interface(
    fn=predict,
    inputs=[
        # predict() expects an array with ``.astype``, i.e. type="numpy".
        gr.Image(type="numpy"),
        # Default to the first model so predict() never receives None here.
        gr.Radio(MODELS_TYPE, value=MODELS_TYPE[0], label="Model"),
    ],
    # predict() returns a PIL image.
    outputs=gr.Image(type="pil"),
    examples=_EXAMPLES,
    title="DTM Estimation",
    description="This demo predicts a DTM...",
)
iface.launch()