bnsapa commited on
Commit
11d1d78
1 Parent(s): 93af1e0

Add image segmentation functionality to app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import shutil
5
+ import os
6
+ import torch
7
+ import TwinLite as net
8
+ from PIL import Image
9
+ from dotenv import load_dotenv
10
+ load_dotenv()
11
+
12
+ share = os.getenv("SHARE", False)
13
+
14
+
15
+ model = net.TwinLiteNet()
16
+ import cv2
17
+
18
+ def Run(model,img):
19
+ img = cv2.resize(img, (640, 360))
20
+ img_rs=img.copy()
21
+
22
+ img = img[:, :, ::-1].transpose(2, 0, 1)
23
+ img = np.ascontiguousarray(img)
24
+ img=torch.from_numpy(img)
25
+ img = torch.unsqueeze(img, 0) # add a batch dimension
26
+ img=img.float() / 255.0
27
+ img = img
28
+ with torch.no_grad():
29
+ img_out = model(img)
30
+ x0=img_out[0]
31
+ x1=img_out[1]
32
+
33
+ _,da_predict=torch.max(x0, 1)
34
+ _,ll_predict=torch.max(x1, 1)
35
+
36
+ DA = da_predict.byte().cpu().data.numpy()[0]*255
37
+ LL = ll_predict.byte().cpu().data.numpy()[0]*255
38
+ img_rs[DA>100]=[255,0,0]
39
+ img_rs[LL>100]=[0,255,0]
40
+
41
+ return img_rs
42
+
43
+
44
+ model = net.TwinLiteNet()
45
+ model = torch.nn.DataParallel(model)
46
+ model.load_state_dict(torch.load('fine-tuned-model.pth', map_location=torch.device('cpu')))
47
+ model.eval()
48
+
49
+
50
+ def predict(image):
51
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
52
+ image.save("input.png")
53
+ img = cv2.imread("input.png")
54
+ img = Run(model, img)
55
+ cv2.imwrite("sample.png", img)
56
+ prediction = Image.open("sample.png")
57
+ return prediction
58
+
59
+
60
+ iface = gr.Interface(fn=predict, inputs="image", outputs="image", title="Image Segmentation")
61
+
62
+ if __name__ == "__main__":
63
+ if share:
64
+ server = "0.0.0.0"
65
+ else:
66
+ server = "127.0.0.1"
67
+ iface.launch(server_name = server)