"""Gradio demo that runs a binary MAnet segmentation model on an uploaded image."""
import gradio as gr
from PIL import Image
import numpy as np
import segmentation_models_pytorch as smp
import torch
from torchvision import transforms as T
# NOTE(review): load_model is never used in this script; kept in case another
# part of the deployment relies on it, but it pulls in all of TensorFlow —
# consider removing.
from tensorflow.keras.models import load_model

# The MAnet encoder downsamples by this factor, so input height/width must be
# multiples of it.
_SIZE_DIVISOR = 32

model = smp.MAnet(
    encoder_name="efficientnet-b7",
    # The full state dict (encoder included) is loaded from weights.pt right
    # below with strict=True, so downloading ImageNet encoder weights first
    # would only be overwritten.
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation="sigmoid",
)
model.load_state_dict(torch.load("weights.pt", map_location=torch.device("cpu")))
model.eval()


def segment(image):
    """Run the segmentation model on one image and return the predicted mask.

    Args:
        image: The image delivered by Gradio's ``"image"`` input. Presumably
            this is an RGB ``numpy`` array of shape (H, W, 3) — confirm
            against the Gradio version in use. It must have 3 channels to
            match ``in_channels=3``.

    Returns:
        A greyscale (mode ``"L"``) ``PIL.Image`` with the same height and
        width as the input. Each pixel is the model's sigmoid output mapped
        from [0, 1] to [0, 255].
    """
    tensor = T.functional.to_tensor(image)  # (C, H, W), float in [0, 1]
    height, width = tensor.shape[-2:]

    # Zero-pad bottom/right up to a multiple of the encoder stride; otherwise
    # arbitrary image sizes make the decoder's skip connections mismatch.
    pad_bottom = (-height) % _SIZE_DIVISOR
    pad_right = (-width) % _SIZE_DIVISOR
    batch = torch.nn.functional.pad(
        tensor.unsqueeze(0), (0, pad_right, 0, pad_bottom)
    )

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = model(batch)  # (1, classes=1, H_pad, W_pad)

    # Crop the padding away explicitly (np.squeeze would also drop a real
    # size-1 height or width).
    probabilities = prediction[0, 0, :height, :width].cpu().numpy()

    # PIL maps float32 arrays to mode "F", which common image encoders cannot
    # write and which displays [0, 1] values as near-black; use 8-bit grey.
    mask = (np.clip(probabilities, 0.0, 1.0) * 255).round().astype(np.uint8)
    return Image.fromarray(mask)


iface = gr.Interface(fn=segment, inputs="image", outputs="image")
iface.launch()