road-detection / app.py
bnsapa's picture
Add image segmentation functionality to app.py
11d1d78
raw
history blame
1.64 kB
import gradio as gr
import torch
import numpy as np
import shutil
import os
import torch
import TwinLite as net
from PIL import Image
from dotenv import load_dotenv
# Load settings via python-dotenv before any os.getenv() lookups are made
# (presumably from a local .env file, if one is present).
load_dotenv()
def env_flag(name, default=False):
    """Read environment variable ``name`` as a boolean flag.

    ``os.getenv`` returns a string, so a setting such as ``SHARE=False`` or
    ``SHARE=0`` would be truthy if used directly in an ``if``.  This helper
    treats the usual "off" spellings as False and any other value as True.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        bool: ``default`` if the variable is unset; False for an empty value
        or one of ``0``/``false``/``no``/``off`` (case-insensitive,
        surrounding whitespace ignored); True otherwise.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# When true, the app is launched on all network interfaces instead of loopback.
share = env_flag("SHARE")
# NOTE(review): `model` is assigned a fresh net.TwinLiteNet() again after Run
# is defined, before any weights are loaded, so this first instance looks
# redundant -- confirm and consider removing.
model = net.TwinLiteNet()
import cv2
def Run(model, img):
    """Run ``model`` on ``img`` and paint its two predicted masks onto a copy.

    Args:
        model: Callable that takes a float tensor of shape (1, 3, 360, 640)
            scaled to [0, 1] and returns a sequence whose first two items are
            per-class score tensors with the class axis at dimension 1.
        img: Image array of shape (H, W, 3).  The channel order is reversed
            before inference (presumably BGR from OpenCV -> RGB; verify
            against the caller).

    Returns:
        The 640x360 resized copy of ``img`` in its original channel order,
        with pixels of the first mask set to [255, 0, 0] and pixels of the
        second mask set to [0, 255, 0] (the second mask wins on overlap).
    """
    resized = cv2.resize(img, (640, 360))
    overlay = resized.copy()

    # Reverse the channel axis and move it first (HWC -> CHW), then add a
    # batch dimension and scale to [0, 1].
    chw = np.ascontiguousarray(resized[:, :, ::-1].transpose(2, 0, 1))
    batch = torch.from_numpy(chw).unsqueeze(0).float() / 255.0

    with torch.no_grad():
        outputs = model(batch)

    # Per-pixel winning class index for each head, converted to a 0/255 mask.
    # The original code names these DA and LL (presumably drivable area and
    # lane lines -- TODO confirm).
    masks = []
    for scores in (outputs[0], outputs[1]):
        _, labels = torch.max(scores, 1)
        masks.append(labels.byte().cpu().numpy()[0] * 255)
    da_mask, ll_mask = masks

    overlay[da_mask > 100] = [255, 0, 0]
    overlay[ll_mask > 100] = [0, 255, 0]
    return overlay
# Build the network, wrap it in DataParallel (presumably because the
# checkpoint was saved from a DataParallel model -- verify), load the
# fine-tuned weights onto the CPU and switch to evaluation mode.
model = torch.nn.DataParallel(net.TwinLiteNet())
model.load_state_dict(
    torch.load('fine-tuned-model.pth', map_location=torch.device('cpu'))
)
model.eval()
def predict(image):
    """Segment an uploaded image and return the annotated result.

    The conversion between the RGB array received from the UI and the
    channel-reversed array handed to ``Run`` is done in memory.  Earlier
    versions round-tripped through the fixed files ``input.png`` and
    ``sample.png``, which let concurrent requests overwrite each other's
    data and cost a disk write/read per call.

    Args:
        image: RGB image as an array of shape (H, W, 3), or None when no
            image was supplied.

    Returns:
        A PIL RGB image holding the output of ``Run`` (640x360 with the
        predicted masks painted on), or None if ``image`` is None.
    """
    if image is None:
        # Nothing to segment; return an empty result instead of raising.
        return None

    rgb = np.asarray(image).astype('uint8')
    # Run receives the channel-reversed (BGR) layout that cv2.imread used to
    # produce from the saved PNG.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    annotated = Run(model, bgr)
    # Flip back to RGB, matching what reading the cv2-written PNG with PIL gave.
    return Image.fromarray(np.ascontiguousarray(annotated[:, :, ::-1]))
# Gradio UI: one image input, one image output, both routed through predict().
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="image",
    title="Image Segmentation",
)
if __name__ == "__main__":
    # Bind to every interface only when sharing was requested; otherwise stay
    # on the loopback address.
    server = "0.0.0.0" if share else "127.0.0.1"
    iface.launch(server_name=server)