import os
import collections

import numpy as np
import gradio as gr

# Runtime installs for the Spaces environment. Note: `collections` is part of
# the Python standard library and must not be pip-installed, so that call was
# dropped; pinning these in requirements.txt is preferable to os.system().
os.system("pip3 install torch")
os.system("pip3 install torchvision")
os.system("pip3 install einops")
os.system("pip3 install opencv-python")
#os.system("pip3 install pydensecrf")
#import pydensecrf.densecrf as dcrf

from PIL import Image
import torch
import cv2
import torch.nn.functional as F
from torchvision import transforms

from model_video import build_model
def show_coord(evt: gr.SelectData):
    # evt.index holds the clicked pixel as (x, y)
    return f"{evt.index[0]},{evt.index[1]}"
def generate_mask(model_type, img, coord):
    # model_type and coord are currently unused: the point/box prompt is not
    # wired into the network yet, so the whole image is segmented.
    # The model expects a group of five images, so the single input is
    # replicated five times and only the first mask is kept.
    img = img.astype(np.uint8)
    mask = sepia(model_type, img, img, img, img, img, stack_image=False)
    # Upsample the 224x224 mask back to the input resolution; the mask arrives
    # as uint8 and bilinear interpolation needs a float tensor.
    mask = F.interpolate(torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0),
                         size=[img.shape[0], img.shape[1]], mode='bilinear').squeeze().numpy()
    # Soft mask in [0, 1], replicated over the three color channels
    mask_torch = torch.from_numpy(mask).squeeze().unsqueeze(2).repeat(1, 1, 3)
    mask_torch = mask_torch / mask_torch.max()
    # Overlay color: the mask intensity with the red channel zeroed out
    col = mask_torch.clone() * 255
    col[:, :, 0] = 0
    img_t = torch.from_numpy(img).float() / img.max() * 255
    # Alpha-blend: keep the background, dim the masked region and tint it
    mix = (1 - mask_torch) * img_t + mask_torch * img_t * 0.5 + mask_torch * col * 0.5
    return mix.numpy().astype(np.uint8)
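# A standalone sketch of the blending used above. `overlay_mask` only ever
# appeared in a comment in the original code, so this helper and its `alpha`
# parameter are assumptions; it expects img as HxWx3 uint8 and mask as HxW
# float in [0, 1].
def overlay_mask(img: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    mask3 = np.repeat(mask[:, :, None], 3, axis=2)  # HxWx3 soft mask
    color = mask3 * 255.0
    color[:, :, 0] = 0                              # zero the red channel, as above
    # (1 - m) * img + m * (alpha * img + (1 - alpha) * color)
    out = (1 - mask3) * img + mask3 * (alpha * img + (1 - alpha) * color)
    return out.astype(np.uint8)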
def create_mode2_interface():
    # Standalone builder for the point-prompt mode (the launched app below uses
    # the tabbed layout instead; this builder is kept for reference).
    with gr.Blocks() as mode2:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="Upload an image and click a point",
                interactive=True
            )
            # Hidden textbox that stores the clicked coordinate
            coord_store = gr.Textbox(visible=False)
            # Record the clicked (x, y) pixel on every select event
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            img_input.select(store_coordinate, inputs=None, outputs=coord_store)
            # generate_mask expects the model name as its first argument
            model_selector = gr.Dropdown(
                choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
                value="SAM-H",
                label="Model"
            )
            btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    return mode2
def create_mode3_interface():
    # Standalone builder for the box-prompt mode (also unused by the launched
    # app below; kept for reference). Note: only a single clicked point is
    # captured here, not a full box.
    with gr.Blocks() as mode3:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="Upload an image and draw a box",
                interactive=True
            )
            # Hidden textbox that stores the clicked coordinate
            coord_store = gr.Textbox(visible=False)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            img_input.select(store_coordinate, inputs=None, outputs=coord_store)
            model_selector = gr.Dropdown(
                choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
                value="SAM-H",
                label="Model"
            )
            btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    return mode3
device = 'cpu'
net = build_model(device).to(device)
model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device(device))
# The checkpoint was saved from a torch.nn.DataParallel model, so every key is
# prefixed with 'module.'; strip that prefix before loading.
new_dict = collections.OrderedDict()
for k in weight.keys():
    new_dict[k[len('module.'):]] = weight[k]
net.load_state_dict(new_dict)
net.eval()
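# Hedged sketch: the loop above assumes every key begins with 'module.'. A
# tolerant variant (an assumption, not the original loading path) would also
# accept checkpoints saved without DataParallel:
def strip_module_prefix(state_dict):
    out = collections.OrderedDict()
    for k, v in state_dict.items():
        # Only strip the prefix when it is actually present
        out[k[len('module.'):] if k.startswith('module.') else k] = v
    return out
# e.g. net.load_state_dict(strip_module_prefix(weight))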
def test(gpu_id, net, img_list, group_size, img_size, stack_image=True):
    print('test')
    # Original input sizes (not restored here; outputs stay at img_size)
    hl, wl = [_.shape[0] for _ in img_list], [_.shape[1] for _ in img_list]
    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    with torch.no_grad():
        # Batch of five images, as expected by the group segmentation model
        group_img = torch.rand(5, 3, img_size, img_size)
        for i in range(5):
            group_img[i] = img_transform(Image.fromarray(img_list[i]))
        _, pred_mask = net(group_img)
        pred_mask = pred_mask.detach().squeeze() * 255
        # Undo the normalization per image so the inputs can be displayed
        img_resize = [((group_img[i] - group_img[i].min()) / (group_img[i].max() - group_img[i].min()) * 255)
                      .permute(1, 2, 0).contiguous().numpy().astype(np.uint8)
                      for i in range(5)]
        pred_mask = [pred_mask[i].numpy().astype(np.uint8) for i in range(5)]
        if not stack_image:
            return pred_mask[0]
        print(pred_mask[0].shape)
        # Two white separator rows between each input and its mask; keep uint8
        # so torch.cat and the final display dtype stay consistent
        white = (torch.ones(2, pred_mask[0].shape[1], 3) * 255).to(torch.uint8)
        result = [torch.cat([torch.from_numpy(img_resize[i]),
                             white,
                             torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1, 1, 3)],
                            dim=0).numpy() for i in range(5)]
        print('done')
        return result
# Smoke test: run the pipeline once on five random images at startup
img_lst = [(torch.rand(352, 352, 3) * 255).numpy().astype(np.uint8) for i in range(5)]
res = test('cpu', net, img_lst, 5, 224)
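# Hedged sketch: test() collects the original heights/widths (hl, wl) but never
# restores them, so outputs stay at the 224x224 model resolution. Resizing each
# mask back to its source size could look like this (assumed helper, mirroring
# the commented-out bilinear resize that sepia once had):
def resize_masks_to_original(masks, heights, widths):
    # cv2.resize takes (width, height); INTER_LINEAR matches bilinear above
    return [cv2.resize(m, (w, h), interpolation=cv2.INTER_LINEAR)
            for m, h, w in zip(masks, heights, widths)]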
def sepia(model_type, img1, img2, img3, img4, img5, stack_image=True):
    # Legacy name: this no longer applies a sepia filter; it runs the group
    # segmentation model on the five input images.
    print('sepia')
    print(img1.shape, img2.shape, img3.shape, img4.shape, img5.shape)
    img_list = [img1, img2, img3, img4, img5]
    # Original sizes (the outputs are not resized back to these yet)
    h_list, w_list = [_.shape[0] for _ in img_list], [_.shape[1] for _ in img_list]
    result_list = test(device, net, img_list, 5, 224, stack_image)
    if not stack_image:
        # Single 224x224 mask for the first image
        return result_list
    img1, img2, img3, img4, img5 = result_list
    # Two white separator columns between adjacent image/mask stacks
    white = (torch.ones(img1.shape[0], 2, 3) * 255).numpy().astype(np.uint8)
    return np.concatenate([img1, white, img2, white, img3, white, img4, white, img5], axis=1)
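# Hedged usage sketch (hypothetical helper and paths): run the group model on
# five image files and return the stacked comparison strip as a numpy array.
def run_group_demo(paths, model_type="SAM-H"):
    imgs = [np.array(Image.open(p).convert('RGB')) for p in paths]
    assert len(imgs) == 5, "the pipeline expects exactly five images"
    return sepia(model_type, *imgs)
# e.g. Image.fromarray(run_group_demo(paths)).save('group_result.png')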
# The launched app below uses gr.Blocks rather than the earlier gr.Interface prototypes.
def create_mode1_interface():
    # Standalone builder for the five-image mode (unused by the launched app
    # below; kept for reference).
    with gr.Blocks() as demo:
        with gr.Row():
            # Five-column grid of inputs
            with gr.Column(scale=1, min_width=150):
                input1 = gr.Image(label="image1", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input2 = gr.Image(label="image2", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input3 = gr.Image(label="image3", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input4 = gr.Image(label="image4", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input5 = gr.Image(label="image5", type="numpy")
        # sepia expects the model name as its first argument
        model_selector = gr.Dropdown(
            choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
            value="SAM-H",
            label="Model"
        )
        btn = gr.Button("start processing")
        with gr.Row():
            output = gr.Image(label="output", type="numpy")
        btn.click(
            fn=sepia,
            inputs=[model_selector, input1, input2, input3, input4, input5],
            outputs=output
        )
    return demo
with gr.Blocks(title="Interactive Group Image Segmentation") as demo:
    # Mode selector
    with gr.Row():
        mode = gr.Radio(
            ["Multi-image co-segmentation", "Point-prompt segmentation", "Box-prompt segmentation"],
            value="Multi-image co-segmentation",
            label="Mode"
        )
        model_selector = gr.Dropdown(
            choices=["RepViT-SAM", "EdgeSAM", "SAM-H"],
            value="SAM-H",
            label="Model",
            container=False  # drop the default container border
        )
    # One Tab per mode instead of separate Blocks
    with gr.Tabs() as mode_container:
        with gr.Tab("Multi-image mode", id=0) as tab1:
            # Mode 1 components
            with gr.Row():
                inputs = [gr.Image(type="numpy", label=f"image {i+1}") for i in range(5)]
            process_btn = gr.Button("Start processing")
            output_img = gr.Image(label="Result")
            process_btn.click(
                sepia,
                inputs=[model_selector] + inputs,
                outputs=output_img
            )
        with gr.Tab("Point-prompt mode", id=1) as tab2:
            # Mode 2 components
            img_input = gr.Image(type="numpy", label="Upload an image and click a point")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            # Store the clicked (x, y) pixel whenever the image is clicked
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            img_input.select(store_coordinate, inputs=None, outputs=coord_store)
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
        with gr.Tab("Box-prompt mode", id=2) as tab3:
            # Mode 3 components (only a clicked point is captured for now,
            # not a full box)
            img_input = gr.Image(type="numpy", label="Upload an image and draw a box")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("Generate segmentation mask")
            mask_output = gr.Image(label="Segmentation result")
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            img_input.select(store_coordinate, inputs=None, outputs=coord_store)
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    # Show only the tab that matches the selected mode
    mode.change(
        lambda x: (gr.update(visible=(x == "Multi-image co-segmentation")),
                   gr.update(visible=(x == "Point-prompt segmentation")),
                   gr.update(visible=(x == "Box-prompt segmentation"))),
        inputs=mode,
        outputs=[tab1, tab2, tab3]
    )
demo.launch(debug=True)