MAnet_BrainMRI / app.py
RHenigan
simplify imports
c2ac5a3
raw
history blame contribute delete
No virus
722 Bytes
import gradio as gr
from PIL import Image
import numpy as np
import segmentation_models_pytorch as smp
import torch
from torchvision import transforms as T
from tensorflow.keras.models import load_model
# Build the MAnet architecture (EfficientNet-B7 encoder, 3 input channels,
# 1 output channel followed by a sigmoid) and load the trained weights on CPU.
#
# encoder_weights=None: requesting "imagenet" here would download pretrained
# encoder weights only for load_state_dict (strict by default, so every
# parameter must come from weights.pt) to overwrite them immediately.
# Skipping them removes a download and a network dependency at startup.
model = smp.MAnet(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation="sigmoid",
)
# assumes weights.pt is a state_dict saved from this same architecture —
# load_state_dict raises if keys or shapes differ.
model.load_state_dict(torch.load("weights.pt", map_location=torch.device("cpu")))
# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()
def segment(image):
    """Run the MAnet model on one image and return its mask as a PIL image.

    Args:
        image: the image supplied by Gradio's ``"image"`` input (presumably an
            H x W x 3 uint8 RGB array — TODO confirm for the Gradio version in
            use); anything ``torchvision``'s ``to_tensor`` accepts works.

    Returns:
        A single-channel 8-bit PIL image with the same height and width as the
        input, where each pixel is the model's sigmoid output scaled to 0-255.
    """
    # to_tensor gives a (C, H, W) float tensor in [0, 1]; add a batch dim.
    batch = T.functional.to_tensor(image)[None, ...]
    height, width = batch.shape[-2:]

    # The encoder downsamples by a factor of 32, so MAnet only accepts H and W
    # divisible by 32. Zero-pad bottom/right here and crop the output below.
    pad_h = -height % 32
    pad_w = -width % 32
    if pad_h or pad_w:
        # Pad tuple order is (left, right, top, bottom) for the last two dims.
        batch = torch.nn.functional.pad(batch, (0, pad_w, 0, pad_h))

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = model(batch)

    # Take batch 0 / channel 0 and drop the padding -> (H, W) probabilities.
    # Explicit indexing (not np.squeeze) keeps a 1-pixel-high/wide image 2-D.
    probabilities = prediction[0, 0, :height, :width].cpu().numpy()

    # A float32 [0, 1] array would become a mode "F" image, which cannot be
    # PNG-encoded for display and would look black; convert to 8-bit gray.
    mask = np.clip(probabilities * 255.0, 0, 255).round().astype(np.uint8)
    return Image.fromarray(mask)
# Gradio UI: an uploaded image goes in, the predicted mask image comes out.
# Keep a handle on the Interface itself (launch() returns server details,
# not the Interface), then start the server.
iface = gr.Interface(fn=segment, inputs="image", outputs="image")
iface.launch()