Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
|
|
4 |
import os
|
5 |
#os.system("sudo apt-get install nvIDia-cuda-toolkit")
|
6 |
os.system("pip3 install torch")
|
7 |
-
os.system("/usr/local/bin/python -m pip install --upgrade pip")
|
8 |
os.system("pip3 install collections")
|
9 |
os.system("pip3 install torchvision")
|
10 |
os.system("pip3 install einops")
|
@@ -87,8 +87,8 @@ def test(gpu_id, net, img_list, group_size, img_size):
|
|
87 |
group_img[i]=img_transform(Image.fromarray(img_list[i]))
|
88 |
_,pred_mask=net(group_img*1)
|
89 |
pred_mask=(pred_mask.detach().squeeze()*255)#.numpy().astype(np.uint8)
|
90 |
-
pred_mask=[F.interpolate(pred_mask[i].reshape(1,1,pred_mask[i].shape[-2],pred_mask[i].shape[-1]),size=(size,size),mode='bilinear').squeeze().numpy().astype(np.uint8) for i in range(5)]
|
91 |
-
|
92 |
#for i in range(5):
|
93 |
# print(img_list[i].shape,pred_mask[i].shape)
|
94 |
pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
|
|
|
4 |
import os
|
5 |
#os.system("sudo apt-get install nvIDia-cuda-toolkit")
|
6 |
os.system("pip3 install torch")
|
7 |
+
#os.system("/usr/local/bin/python -m pip install --upgrade pip")
|
8 |
os.system("pip3 install collections")
|
9 |
os.system("pip3 install torchvision")
|
10 |
os.system("pip3 install einops")
|
|
|
87 |
group_img[i]=img_transform(Image.fromarray(img_list[i]))
|
88 |
_,pred_mask=net(group_img*1)
|
89 |
pred_mask=(pred_mask.detach().squeeze()*255)#.numpy().astype(np.uint8)
|
90 |
+
#pred_mask=[F.interpolate(pred_mask[i].reshape(1,1,pred_mask[i].shape[-2],pred_mask[i].shape[-1]),size=(size,size),mode='bilinear').squeeze().numpy().astype(np.uint8) for i in range(5)]
|
91 |
+
pred_mask=[crf_refine(((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8),pred_mask[i]) for i in range(5)]
|
92 |
#for i in range(5):
|
93 |
# print(img_list[i].shape,pred_mask[i].shape)
|
94 |
pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
|