"""Gradio demo that segments an uploaded image with a MAnet model.

The model weights are read from ``weights.pt`` in the working directory, and
the web interface is launched when this module is executed.
"""

from glob import glob
import os
import time

import albumentations as A
import cv2
import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
# `scipy.ndimage.morphology` is a deprecated alias; import from the public namespace.
from scipy.ndimage import binary_dilation
import segmentation_models_pytorch as smp
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
# NOTE(review): not used by the code in this file; importing TensorFlow adds
# significant startup time — consider removing if nothing else relies on it.
from tensorflow.keras.models import load_model
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm import tqdm

# With smp's default encoder depth of 5 the network downsamples by 2**5, so
# the input height and width must be multiples of this value.
_SIZE_DIVISOR = 32

model = smp.MAnet(
    encoder_name="efficientnet-b7",
    # Every parameter is overwritten by `load_state_dict` below, so fetching
    # the ImageNet encoder weights at startup would be a wasted download.
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation="sigmoid",
)

# Training-time augmentations. Each one fires at random (p=0.3), so they are
# deliberately NOT applied in `segment`: inference must be deterministic.
transform = A.Compose([
    A.ChannelDropout(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.ColorJitter(p=0.3),
])

model.load_state_dict(torch.load("weights.pt", map_location=torch.device("cpu")))
model.eval()


def _pad_to_multiple(batch, divisor=_SIZE_DIVISOR):
    """Zero-pad the last two dims of *batch* up to the next multiple of *divisor*.

    Padding is added only on the bottom and right, so the original content
    stays at the top-left and can be recovered by cropping ``[:H, :W]``.
    """
    height, width = batch.shape[-2:]
    pad_h = -height % divisor
    pad_w = -width % divisor
    if pad_h == 0 and pad_w == 0:
        return batch
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return nn.functional.pad(batch, (0, pad_w, 0, pad_h))


def _probabilities_to_image(probabilities):
    """Convert a 2-D array of values in [0, 1] into an 8-bit grayscale image.

    A float array would become a mode ``"F"`` PIL image, which cannot be
    encoded as PNG for display, so the values are rescaled to 0..255 first.
    """
    scaled = np.clip(probabilities, 0.0, 1.0) * 255.0
    return Image.fromarray(np.rint(scaled).astype(np.uint8))


def segment(image):
    """Run the segmentation model on *image* and return its probability map.

    Args:
        image: the image from Gradio's ``"image"`` input, or ``None`` when
            nothing was uploaded. Presumably an RGB ``uint8`` array of shape
            (H, W, 3) — TODO confirm against the Gradio version in use.

    Returns:
        A grayscale ``PIL.Image`` (mode ``"L"``) with the input's height and
        width, where 0 maps to probability 0 and 255 to probability 1; or
        ``None`` if *image* is ``None``.
    """
    if image is None:
        return None
    tensor = T.functional.to_tensor(image)  # (C, H, W), float
    height, width = tensor.shape[-2:]
    batch = _pad_to_multiple(tensor[None, ...])
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        prediction = model(batch)
    # Output is (batch=1, classes=1, H_pad, W_pad); index explicitly rather than
    # squeeze (which would also drop 1-pixel dims) and crop the padding away.
    probabilities = prediction[0, 0, :height, :width].cpu().numpy()
    return _probabilities_to_image(probabilities)


# Keep a handle on the Interface itself (not on launch()'s return value).
iface = gr.Interface(fn=segment, inputs="image", outputs="image")
iface.launch()