# MAnet_BrainMRI / app.py
# Author: RHenigan — commit 44a9d05 ("Make gradio app")
# (Hugging Face file-page header: raw · history · blame · no virus detected · 1.41 kB)
import gradio as gr
from glob import glob
import os
import time
from PIL import Image
import albumentations as A
import cv2
import numpy as np
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
from scipy.ndimage.morphology import binary_dilation
import segmentation_models_pytorch as smp
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm import tqdm
from tensorflow.keras.models import load_model
# Build the MAnet architecture that weights.pt was trained with.
# The encoder is created WITHOUT pretrained weights: load_state_dict() below
# (strict by default) overwrites every parameter. Requesting
# encoder_weights="imagenet" only made every start-up download the
# efficientnet-b7 ImageNet checkpoint and then throw it away.
model = smp.MAnet(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation='sigmoid',
)

# Random augmentation pipeline: each op fires independently with p=0.3.
# It is training-style augmentation and should not be applied to images
# whose prediction must be deterministic (i.e. at inference time).
transform = A.Compose([
    A.ChannelDropout(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.ColorJitter(p=0.3),
])

# Load the trained weights, remapping any saved device tensors onto the CPU,
# and switch to eval mode so dropout/batch-norm behave deterministically.
model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu')))
model.eval()
def segment(image):
    """Predict a segmentation probability mask for one uploaded image.

    Args:
        image: The array delivered by the Gradio ``"image"`` input, presumably
            an ``H x W x 3`` uint8 RGB array (2-D grayscale and 4-channel
            RGBA arrays are also accepted), or ``None`` when the user submits
            without uploading anything.

    Returns:
        A grayscale ``PIL.Image`` (mode ``"L"``) with the same height and
        width as the input. Pixel value 255 corresponds to a model
        probability of 1.0. Returns ``None`` if no image was provided.
    """
    if image is None:
        return None

    image = np.asarray(image)
    # The model was built with in_channels=3: promote grayscale, drop alpha.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    # No augmentation here: the module-level `transform` applies random
    # ops, which made predictions non-deterministic and degraded the input.

    # smp encoders (default depth 5) require H and W divisible by 32, so
    # zero-pad bottom/right up to the next multiple and crop back afterwards.
    height, width = image.shape[:2]
    padded = np.pad(
        image,
        ((0, (-height) % 32), (0, (-width) % 32), (0, 0)),
        mode="constant",
    )

    tensor = T.functional.to_tensor(padded)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(tensor[None, ...])
    # Output is (batch=1, classes=1, H', W'); keep the un-padded region.
    prediction = prediction[0, 0, :height, :width].cpu().numpy()

    # Sigmoid output is float32 in [0, 1]. A mode "F" image cannot be encoded
    # as PNG for display, so rescale to 8-bit grayscale.
    mask = np.clip(prediction * 255.0, 0, 255).round().astype(np.uint8)
    return Image.fromarray(mask)
# Web UI: upload an image, receive the predicted mask as an image.
# Keep a handle on the Interface itself (previously `iface` was bound to the
# return value of .launch(), not the Interface), then start the server.
iface = gr.Interface(fn=segment, inputs="image", outputs="image")
iface.launch()