# GDNet_2025 / app.py
import tqdm
#import fastCNN
import numpy as np
import gradio as gr
import os
#os.system("sudo apt-get install nvIDia-cuda-toolkit")
os.system("pip3 install torch")
#os.system("/usr/local/bin/python -m pip install --upgrade pip")
os.system("pip3 install collections")
os.system("pip3 install torchvision")
os.system("pip3 install einops")
os.system("pip3 install opencv-python")
aaaa=0
#os.system("pip3 install pydensecrf")
#os.system("pip install argparse")
#import pydensecrf.densecrf as dcrf
from PIL import Image
import torch
import cv2
import torch.nn.functional as F
from torchvision import transforms
from model_video import build_model
import collections
def show_coord(evt: gr.SelectData):
    # Format a Gradio select event's pixel coordinates as "x,y".
    return f"{evt.index[0]},{evt.index[1]}"
def generate_mask(model_type, img, coord):
    # `coord` holds the clicked point as "x,y", but the current model ignores
    # it and segments the whole image; the same frame is replicated five times
    # to satisfy the group-of-five input the network expects.
    frame = (img * 0.999999).astype(np.uint8)  # the *0.999999 forces a copy before casting back to uint8
    mask = sepia(model_type, frame, frame, frame, frame, frame, stack_image=False)
    # Upsample the low-resolution prediction back to the original image size.
    mask = F.interpolate(torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0),
                         size=[img.shape[0], img.shape[1]], mode='bilinear').squeeze().numpy()
    # Expand the single-channel mask to three channels and normalize to [0, 1].
    col = torch.from_numpy(mask).squeeze().unsqueeze(2).repeat(1, 1, 3)
    col = col / col.max()
    mask_torch = col.clone()
    img = img / img.max() * 255
    col = col * 255
    col[:, :, 0] = 0  # zero the red channel so the overlay tint is cyan
    # Blend: keep the background, darken the foreground, and add the tint.
    mix = (1 - mask_torch) * img + mask_torch * img * 0.5 + mask_torch * col * 0.5
    return mix.numpy().astype(np.uint8)
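# Hedged sketch: if the point prompt were actually wired into the model, the
# "x,y" string stored by the select handlers below could be decoded like this.
# parse_coord is an illustrative helper name; nothing in the app calls it yet.
def parse_coord(coord):
    """Parse an "x,y" textbox value into integer pixel coordinates."""
    x, y = map(int, coord.split(','))
    return x, y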
def create_mode2_interface():
    # Standalone builder for the point-prompt mode (kept for reference; the
    # main Blocks app below re-creates this UI inline).
    with gr.Blocks() as mode2:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],  # `sources` expects a list
                label="Upload an image and click a point",
                interactive=True
            )
            # Hidden textbox that stores the clicked coordinate as "x,y".
            coord_store = gr.Textbox(visible=False)
            # Bind the click event; the handler only needs the event payload.
            @img_input.select(inputs=[], outputs=coord_store)
            def capture_coordinates(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            # generate_mask expects (model_type, img, coord); pass a fixed
            # model name (the model choice is currently unused anyway).
            model_store = gr.Textbox(value="SAM-H", visible=False)
            btn.click(
                generate_mask,
                inputs=[model_store, img_input, coord_store],
                outputs=mask_output
            )
    return mode2
def create_mode3_interface():
    # Standalone builder for the box-prompt mode (kept for reference; only a
    # single click is captured for now, same as the point-prompt mode).
    with gr.Blocks() as mode3:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],  # `sources` expects a list
                label="Upload an image and draw a box",
                interactive=True
            )
            # Hidden textbox that stores the clicked coordinate as "x,y".
            coord_store = gr.Textbox(visible=False)
            @img_input.select(inputs=[], outputs=coord_store)
            def capture_coordinates(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            # generate_mask expects (model_type, img, coord); pass a fixed
            # model name (the model choice is currently unused anyway).
            model_store = gr.Textbox(value="SAM-H", visible=False)
            btn.click(
                generate_mask,
                inputs=[model_store, img_input, coord_store],
                outputs=mask_output
            )
    return mode3
#import argparse
device = 'cpu'
net = build_model(device).to(device)
#net=torch.nn.DataParallel(net)
model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device(device))
# The checkpoint was saved from a torch.nn.DataParallel model, so every key
# carries a 'module.' prefix that must be stripped before loading.
new_dict = collections.OrderedDict()
for k in weight.keys():
    new_dict[k[len('module.'):]] = weight[k]
net.load_state_dict(new_dict)
net.eval()
net = net.to(device)
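# Hedged sketch: a reusable version of the prefix stripping above. The helper
# name strip_module_prefix is an assumption, not part of the original code; it
# also tolerates checkpoints whose keys were saved without the prefix.
def strip_module_prefix(state_dict):
    """Drop the 'module.' prefix that torch.nn.DataParallel adds to keys."""
    return collections.OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )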
def test(gpu_id, net, img_list, group_size, img_size, stack_image=True):
    """Run the co-segmentation net on a group of five images.

    gpu_id and group_size are kept for API compatibility but unused: inference
    runs on the global `device` and the group size is fixed at five. Returns
    the first predicted mask when stack_image=False, otherwise a list of five
    image/mask strips stacked vertically.
    """
    print('test')
    # Original heights/widths of the inputs (collected but unused downstream).
    hl, wl = [_.shape[0] for _ in img_list], [_.shape[1] for _ in img_list]
    img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Grayscale variant of the transform (defined but not used below).
    img_transform_gray = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.449], std=[0.226])])
    with torch.no_grad():
        # Batch the five preprocessed images into one 5x3xHxW tensor.
        group_img = torch.zeros(5, 3, img_size, img_size)
        for i in range(5):
            group_img[i] = img_transform(Image.fromarray(img_list[i]))
        _, pred_mask = net(group_img)
        pred_mask = pred_mask.detach().squeeze() * 255
        # Undo the normalization so the resized inputs can be shown next to the masks.
        img_resize = [((group_img[i] - group_img[i].min()) / (group_img[i].max() - group_img[i].min()) * 255)
                      .permute(1, 2, 0).contiguous().numpy().astype(np.uint8) for i in range(5)]
        pred_mask = [pred_mask[i].numpy().astype(np.uint8) for i in range(5)]
        if not stack_image:
            return pred_mask[0]
        print(pred_mask[0].shape)
        # 2-pixel white separator; uint8 so torch.cat dtypes match the images.
        white = (torch.ones(2, pred_mask[0].shape[1], 3) * 255).byte()
        result = [torch.cat([torch.from_numpy(img_resize[i]), white,
                             torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1, 1, 3)], dim=0).numpy()
                  for i in range(5)]
        print('done')
        return result
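# Hedged sketch: test() records each input's original height/width (hl, wl) but
# never uses them, which suggests the masks were meant to be resized back to
# the source resolution. A minimal version of that missing step might look like
# this (resize_masks_to_inputs is a hypothetical helper, unused by the app):
def resize_masks_to_inputs(masks, sizes):
    """Bilinearly resize each HxW uint8 mask to its (h, w) pair in `sizes`."""
    out = []
    for m, (h, w) in zip(masks, sizes):
        t = torch.from_numpy(m).float().unsqueeze(0).unsqueeze(0)
        out.append(F.interpolate(t, size=(h, w), mode='bilinear').squeeze().numpy().astype(np.uint8))
    return out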
# Simple smoke test: run five random images through the full pipeline at startup.
img_lst = [(torch.rand(352, 352, 3) * 255).numpy().astype(np.uint8) for i in range(5)]
res = test('cpu', net, img_lst, 5, 224)
# (The disabled assertion below expected outputs at the original 352x352 size,
# but test() returns results at the 224x224 network resolution.)
'''for i in range(5):
    assert res[i].shape[0]==352 and res[i].shape[1]==352 and res[i].shape[2]==3'''
def sepia(model_type, img1, img2, img3, img4, img5, stack_image=True):
    """Group-segmentation entry point.

    The name is left over from the Gradio sepia-filter template this demo
    started from; the sepia filter itself was removed. `model_type` comes from
    the UI dropdown but is not used by the current model.
    """
    print('sepia')
    print(img1.shape, img2.shape, img3.shape, img4.shape, img5.shape)
    img_list = [img1, img2, img3, img4, img5]
    # Original sizes (currently unused; outputs stay at network resolution).
    h_list, w_list = [_.shape[0] for _ in img_list], [_.shape[1] for _ in img_list]
    result_list = test(device, net, img_list, 5, 224, stack_image)
    if not stack_image:
        return result_list
    img1, img2, img3, img4, img5 = result_list
    # Join the five image/mask strips horizontally with 2-pixel white separators.
    white = (torch.ones(img1.shape[0], 2, 3) * 255).numpy().astype(np.uint8)
    return np.concatenate([img1, white, img2, white, img3, white, img4, white, img5], axis=1)
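# Hedged sketch: the strip returned by sepia() is five panels joined by 2-pixel
# white separators, so its width is the sum of the panel widths plus the
# separators. expected_strip_width is a hypothetical check, not used by the app.
def expected_strip_width(widths, sep=2):
    """Width of a horizontal strip of panels joined by `sep`-pixel separators."""
    return sum(widths) + sep * (len(widths) - 1)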
# Earlier gr.Interface-based versions of this demo, since replaced by the
# Blocks UI below.
def create_mode1_interface():
    # Standalone builder for the multi-image mode (kept for reference; the main
    # Blocks app below re-creates this UI inline).
    with gr.Blocks() as demo:
        with gr.Row():
            # Five-column grid layout, one upload slot per image.
            with gr.Column(scale=1, min_width=150):
                input1 = gr.Image(label="image1", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input2 = gr.Image(label="image2", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input3 = gr.Image(label="image3", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input4 = gr.Image(label="image4", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input5 = gr.Image(label="image5", type="numpy")
        btn = gr.Button("start processing")
        with gr.Row():
            output = gr.Image(label="output", type="numpy")
        # sepia() expects a model name as its first argument; pass a fixed one.
        model_store = gr.Textbox(value="SAM-H", visible=False)
        btn.click(
            fn=sepia,
            inputs=[model_store, input1, input2, input3, input4, input5],
            outputs=output
        )
    return demo
with gr.Blocks(title="Interactive Image Group Segmentation") as demo:
    # Mode selector
    with gr.Row():
        mode = gr.Radio(
            ["Multi-image co-segmentation", "Point-prompt segmentation", "Box-prompt segmentation"],
            value="Multi-image co-segmentation",
            label="Mode"
        )
        model_selector = gr.Dropdown(
            choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
            value="SAM-H",
            label="Model",
            container=False  # drop the default container border
        )
    # Tabs instead of separate Blocks instances
    with gr.Tabs() as mode_container:
        with gr.Tab("Multi-image mode", id=0) as tab1:
            # Mode-1 widgets
            with gr.Row():
                inputs = [gr.Image(type="numpy", label=f"Image {i+1}") for i in range(5)]
            process_btn = gr.Button("Start processing")
            output_img = gr.Image(label="Result")
            process_btn.click(
                sepia,
                inputs=[model_selector] + inputs,
                outputs=output_img
            )
        with gr.Tab("Point-prompt mode", id=1) as tab2:
            # Mode-2 widgets
            img_input = gr.Image(type="numpy", label="Upload an image and click a point")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
        with gr.Tab("Box-prompt mode", id=2) as tab3:
            # Mode-3 widgets (currently captures a single click, like mode 2)
            img_input = gr.Image(type="numpy", label="Upload an image and draw a box")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    # Show/hide tabs to match the selected mode
    mode.change(
        lambda x: (gr.update(visible=(x == "Multi-image co-segmentation")),
                   gr.update(visible=(x == "Point-prompt segmentation")),
                   gr.update(visible=(x == "Box-prompt segmentation"))),
        inputs=mode,
        outputs=[tab1, tab2, tab3]
    )
demo.launch(debug=True)