"""Gradio demo for group (co-)segmentation.

The script loads a segmentation network from ``image_best.pth``, runs it on a
group of five images and shows the predicted masks either stacked under the
resized inputs (multi-image mode) or blended over a single image
(point / box prompt modes).

NOTE(review): the module installs packages, loads the checkpoint, runs a
smoke test and launches the UI at import time.  That behaviour is kept
unchanged because deployment may rely on it.
"""
import tqdm
#import fastCNN
import numpy as np
import gradio as gr
import os

# Runtime dependency bootstrap.  These installs must run before the imports
# below, which is why the imports are not all grouped at the top of the file.
# ``collections`` was removed from the list: it is part of the standard
# library and is not something pip should install.
for _pkg in ("torch", "torchvision", "einops", "opencv-python"):
    os.system(f"pip3 install {_pkg}")

aaaa = 0  # unused; kept only so the module namespace stays backward-compatible

from PIL import Image
import torch
import cv2
import torch.nn.functional as F
from torchvision import transforms
from model_video import build_model
import collections

# Choices of the "run mode" radio button, in the same order as the tabs.
_MODES = ("多图协同分割", "点提示交互分割", "框提示交互分割")


def show_coord(evt: gr.SelectData):
    """Format the index carried by a Gradio select event as ``"a,b"``."""
    return f"{evt.index[0]},{evt.index[1]}"


def generate_mask(model_type, img, coord):
    """Segment a single image and return it with the mask blended on top.

    Args:
        model_type: Value of the model selector.  It is forwarded to
            ``sepia``, which currently does not use it.
        img: Input image as a uint8 numpy array (what ``gr.Image`` with
            ``type="numpy"`` delivers).  Grayscale and RGBA inputs are
            converted to RGB.
        coord: Stored click coordinate.  Currently unused; the parsing code
            was already disabled in the original implementation.

    Returns:
        ``H x W x 3`` uint8 array: the intensity-normalised image, tinted
        (green + blue channels) proportionally to the predicted mask.

    Raises:
        gr.Error: If no image has been uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    rgb = np.asarray(Image.fromarray(img).convert('RGB'))

    # NOTE(review): the 0.999999 scaling (which floors every non-zero pixel
    # down by one) is kept from the original so the model input is unchanged.
    model_input = (rgb * 0.999999).astype(np.uint8)
    # The network works on a group of five images, so the same image is
    # passed five times and only the first mask is used.
    mask = sepia(model_type, model_input, model_input, model_input,
                 model_input, model_input, stack_image=False)

    # Resize the mask to the image size.  The interpolation runs on float
    # data because bilinear resizing of uint8 tensors is not supported by
    # every torch version.
    height, width = rgb.shape[:2]
    mask_full = F.interpolate(
        torch.from_numpy(mask).float()[None, None],
        size=[height, width],
        mode='bilinear',
    )[0, 0].numpy()

    # Blend weight in [0, 1]; an all-zero mask yields an all-zero weight
    # instead of a division by zero.
    peak = mask_full.max()
    if peak > 0:
        weight = (mask_full / peak)[:, :, None]
    else:
        weight = np.zeros_like(mask_full)[:, :, None]

    img_peak = rgb.max()
    if img_peak > 0:
        img_norm = rgb / img_peak * 255
    else:
        img_norm = rgb.astype(np.float64)

    # Overlay colour: the mask scaled to 0..255 with the first channel zeroed.
    tint = weight * np.array([0.0, 255.0, 255.0])

    mix = (1 - weight) * img_norm + weight * img_norm * 0.5 + weight * tint * 0.5
    return np.clip(mix, 0, 255).astype(np.uint8)


def _model_dropdown():
    """Create the model selector dropdown used by every interface."""
    return gr.Dropdown(
        choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
        value="SAM-H",
        label="选择模型",
        container=False,  # drop the default container border
    )


def _build_prompt_ui(model_selector, image_label, **image_kwargs):
    """Create the widgets of a prompt mode and wire them to ``generate_mask``.

    Must be called inside an active ``gr.Blocks`` context.  Extra keyword
    arguments are forwarded to the input ``gr.Image``.
    """
    img_input = gr.Image(type="numpy", label=image_label, **image_kwargs)
    # Hidden textbox that stores the last clicked coordinate.
    coord_store = gr.Textbox(visible=False)
    mask_btn = gr.Button("生成分割掩码")
    mask_output = gr.Image(label="分割结果")

    # A single select handler; the event data is injected through the
    # ``gr.SelectData`` annotation of ``show_coord``.
    img_input.select(show_coord, inputs=None, outputs=coord_store)
    mask_btn.click(
        generate_mask,
        inputs=[model_selector, img_input, coord_store],
        outputs=mask_output,
    )


def _create_prompt_interface(image_label):
    """Build a standalone ``gr.Blocks`` for one prompt mode."""
    with gr.Blocks() as mode2:
        with gr.Column():
            # generate_mask takes the model type as its first argument, so
            # the standalone interface needs its own selector.
            model_selector = _model_dropdown()
            _build_prompt_ui(
                model_selector,
                image_label,
                sources=["upload"],
                interactive=True,
            )
    return mode2


def create_mode2_interface():
    """Return a standalone ``gr.Blocks`` for the point-prompt mode."""
    return _create_prompt_interface("点击上传图片并选择点")


def create_mode3_interface():
    """Return a standalone ``gr.Blocks`` for the box-prompt mode."""
    return _create_prompt_interface("点击上传图片并选择框")


device = 'cpu'
net = build_model(device).to(device)
model_path = 'image_best.pth'
print(model_path)
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
weight = torch.load(model_path, map_location=torch.device(device))
# The original code removed a 'module.' prefix from every key.  Strip it only
# where it is actually present so checkpoints saved without the prefix load
# as well.
_PREFIX = 'module.'
new_dict = collections.OrderedDict(
    (k[len(_PREFIX):] if k.startswith(_PREFIX) else k, v)
    for k, v in weight.items()
)
net.load_state_dict(new_dict)
net.eval()


def _to_display_uint8(tensor):
    """Min-max scale a ``C x H x W`` tensor to a ``H x W x C`` uint8 array."""
    low = tensor.min()
    # A constant image has zero range; fall back to 1 to avoid NaNs.
    span = float(tensor.max() - low) or 1.0
    scaled = (tensor - low) / span * 255
    return scaled.permute(1, 2, 0).contiguous().numpy().astype(np.uint8)


def test(gpu_id, net, img_list, group_size, img_size, stack_image=True):
    """Run the network on a group of images.

    Args:
        gpu_id: Unused; kept for interface compatibility.
        net: Callable model returning a pair whose second element holds the
            predicted masks (presumably with values in [0, 1] -- TODO confirm).
        img_list: Sequence of uint8 image arrays; the first ``group_size``
            entries are used.
        group_size: Number of images fed to the network as one group.
        img_size: Side length the images are resized to.
        stack_image: If false, return only the first predicted mask.

    Returns:
        With ``stack_image=False`` the first mask as a uint8 array.
        Otherwise a list with one uint8 array per image: the resized input,
        a 2-pixel white separator and the 3-channel mask stacked vertically.

    Raises:
        ValueError: If ``img_list`` holds fewer than ``group_size`` images.
    """
    print('test')
    if len(img_list) < group_size:
        raise ValueError(
            f"expected at least {group_size} images, got {len(img_list)}"
        )

    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    with torch.no_grad():
        # convert('RGB') lets grayscale / RGBA inputs through the 3-channel
        # normalisation; it leaves RGB inputs unchanged.
        group_img = torch.stack([
            img_transform(Image.fromarray(img).convert('RGB'))
            for img in img_list[:group_size]
        ])
        # The network gets a copy (as the original ``group_img*1`` did)
        # because ``group_img`` is reused below for display.
        _, pred_mask = net(group_img.clone())

    pred_mask = pred_mask.detach().squeeze() * 255
    pred_mask = [pred_mask[i].numpy().astype(np.uint8)
                 for i in range(group_size)]
    if not stack_image:
        return pred_mask[0]

    img_resize = [_to_display_uint8(group_img[i]) for i in range(group_size)]
    print(pred_mask[0].shape)

    # The separator is built as uint8 so the stacked result stays uint8; the
    # original int64 separator promoted the whole output to int64.
    white = np.full((2, pred_mask[0].shape[1], 3), 255, dtype=np.uint8)
    result = [
        np.concatenate(
            [picture, white, np.repeat(mask[:, :, None], 3, axis=2)],
            axis=0,
        )
        for picture, mask in zip(img_resize, pred_mask)
    ]
    print('done')
    return result


# Import-time smoke test on random images.
img_lst = [(torch.rand(352, 352, 3) * 255).numpy().astype(np.uint8)
           for i in range(5)]
res = test('cpu', net, img_lst, 5, 224)


def sepia(model_type, img1, img2, img3, img4, img5, stack_image=True):
    """Segment a group of five images with the module-level ``net``.

    Args:
        model_type: Value of the model selector; currently unused.
        img1, img2, img3, img4, img5: Input images as uint8 numpy arrays.
        stack_image: Forwarded to ``test``.

    Returns:
        With ``stack_image=False`` whatever ``test`` returns (the first
        mask).  Otherwise one uint8 image in which the five per-image
        results are placed side by side, separated by 2-pixel white columns.

    Raises:
        gr.Error: If any of the five images is missing.
    """
    print('sepia')
    img_list = [img1, img2, img3, img4, img5]
    if any(img is None for img in img_list):
        raise gr.Error("Please provide all five input images.")
    print(img1.shape, img2.shape, img3.shape, img4.shape, img5.shape)

    result_list = test(device, net, img_list, 5, 224, stack_image)
    if not stack_image:
        return result_list

    white = np.full((result_list[0].shape[0], 2, 3), 255, dtype=np.uint8)
    pieces = []
    for position, panel in enumerate(result_list):
        if position:
            pieces.append(white)
        pieces.append(panel)
    return np.concatenate(pieces, axis=1)


def create_mode1_interface():
    """Return a standalone ``gr.Blocks`` for the multi-image mode."""
    with gr.Blocks() as demo:
        # sepia takes the model type as its first argument.
        model_selector = _model_dropdown()
        with gr.Row():
            # Five-column grid, one input image per column.
            image_inputs = []
            for number in range(1, 6):
                with gr.Column(scale=1, min_width=150):
                    image_inputs.append(
                        gr.Image(label=f"image{number}", type="numpy")
                    )
        btn = gr.Button("start processing")
        with gr.Row():
            output = gr.Image(label="output", type="numpy")
        btn.click(
            fn=sepia,
            inputs=[model_selector] + image_inputs,
            outputs=output,
        )
    return demo


def _toggle_tabs(selected):
    """Return one visibility update per tab for the selected run mode."""
    return tuple(gr.update(visible=selected == name) for name in _MODES)


with gr.Blocks(title="交互式图像组分割系统") as demo:
    # Mode and model selectors.
    with gr.Row():
        mode = gr.Radio(
            list(_MODES),
            value="多图协同分割",
            label="运行模式",
        )
        model_selector = _model_dropdown()

    # Tabs are used instead of independent Blocks.
    with gr.Tabs() as mode_container:
        with gr.Tab("多图模式", id=0) as tab1:
            # Multi-image mode widgets.
            with gr.Row():
                inputs = [gr.Image(type="numpy", label=f"图像{i+1}")
                          for i in range(5)]
            process_btn = gr.Button("开始处理")
            output_img = gr.Image(label="处理结果")
            process_btn.click(
                sepia,
                inputs=[model_selector] + inputs,
                outputs=output_img,
            )

        with gr.Tab("点选交互模式", id=1) as tab2:
            # Point-prompt mode widgets.
            _build_prompt_ui(model_selector, "点击上传图片并选择点")

        with gr.Tab("框选交互模式", id=2) as tab3:
            # Box-prompt mode widgets.
            _build_prompt_ui(model_selector, "点击上传图片并选择框")

    # Show only the tab that matches the selected run mode.
    mode.change(
        _toggle_tabs,
        inputs=mode,
        outputs=[tab1, tab2, tab3],
    )

# NOTE(review): launched at import time, as in the original script.
demo.launch(debug=True)