# -*- coding: utf-8 -*-
"""Gradio demo for RingMo-SAM.

Four tabs: optical instance segmentation, optical terrain segmentation,
multi-box prompts, and SAR segmentation with polarization scatter prompts.
"""
import sys
import io
import requests
import json
import base64
import functools
from PIL import Image
import numpy as np
import gradio as gr
import mmengine
from mmengine import Config, get
import argparse
import os
import cv2
import yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import datasets
import models
from torchvision import transforms
from mmcv.runner import load_checkpoint
import visual_utils
from models.utils_prompt import get_prompt_inp, pre_prompt, pre_scatter_prompt, get_prompt_inp_scatter

device = torch.device("cpu")

# Model definitions/weights used by the optical tabs and by the SAR tab.
_OPTICAL_CONFIG = 'configs/fine_tuning_one_decoder.yaml'
_OPTICAL_CHECKPOINT = './save/model_epoch_last.pth'
_SAR_CONFIG = 'configs/multi_mo_multi_task_sar_prompt.yaml'
_SAR_CHECKPOINT = './save/SAR/model_epoch_last.pth'

# Every input image is resized to a square of this side length before encoding.
_INPUT_SIZE = 1024
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_TO_MODEL_TENSOR = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Terrain/instance merge constants (see _merge_terrain_and_instance).
_TERRAIN_BACKGROUND = 12            # background index of the terrain decoder output
_CONFLICT_FREE_SUMS = (1, 2, 5, 6, 14)
_UNIFIED_NUM_CLASSES = 15
_TERRAIN_CHANNELS = (0, 1, 2, 6)    # logits copied from the terrain decoder
_INSTANCE_CHANNELS = (5, 14)        # logits copied from the instance decoder

# Inset (in pixels) of a box outline; the original drew it with radius 4 - 1.
_BOX_BORDER = 3
_BOX_COLOR = (255, 0, 0)


def batched_predict(model, inp, coord, bsize):
    """Query ``model`` over ``coord`` in chunks of ``bsize`` along dim 1.

    Calls ``model.gen_feat(inp)`` once, then ``model.query_rgb`` on successive
    slices ``coord[:, ql:qr, :]``. Returns the chunk predictions concatenated
    along dim 1, together with the list of per-chunk predictions.
    """
    with torch.no_grad():
        model.gen_feat(inp)
        n = coord.shape[1]
        ql = 0
        preds = []
        while ql < n:
            qr = min(ql + bsize, n)
            preds.append(model.query_rgb(coord[:, ql:qr, :]))
            ql = qr
        pred = torch.cat(preds, dim=1)
    return pred, preds


def tensor2PIL(tensor):
    """Convert a tensor to a PIL image with torchvision's ToPILImage."""
    return transforms.ToPILImage()(tensor)


@functools.lru_cache(maxsize=None)
def _load_model(config_path, checkpoint_path, strict):
    """Build the model from ``config_path`` and load ``checkpoint_path`` into it.

    Cached per (config, checkpoint, strict) triple. The weights are read once
    per process instead of on every request. The returned model is on CPU and
    in eval mode, and it is shared between calls, so callers must not mutate it.
    """
    with open(config_path, 'r') as f:
        # NOTE(review): FullLoader kept from the original (local, trusted
        # config files); prefer yaml.safe_load if no custom tags are used.
        config = yaml.load(f, Loader=yaml.FullLoader)
    model = models.make(config['model']).cpu()
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=strict)
    model.eval()
    return model


def _to_model_input(image):
    """Resize a PIL image to _INPUT_SIZE x _INPUT_SIZE and normalize it.

    Returns a float tensor of shape (1, 3, _INPUT_SIZE, _INPUT_SIZE).
    Raises ValueError when no image was supplied.
    """
    if image is None:
        raise ValueError("Please provide an input image.")
    # convert('RGB') guarantees the 3 channels that Normalize expects.
    resized = transforms.Resize([_INPUT_SIZE, _INPUT_SIZE])(image.convert('RGB'))
    return _TO_MODEL_TENSOR(resized).unsqueeze(0)


def _encode(model, input_img, scatter=None):
    """Run the image encoder and the prompt encoder (no point/box/mask prompts).

    Returns (image_embedding, sparse_embeddings, dense_embeddings).
    """
    image_embedding = model.image_encoder(input_img)
    sparse_embeddings, dense_embeddings, _ = model.prompt_encoder(
        points=None, boxes=None, masks=None, scatter=scatter)
    return image_embedding, sparse_embeddings, dense_embeddings


def _decode(decoder, model, image_embedding, sparse_embeddings, dense_embeddings,
            multimask_output):
    """Run one mask decoder and upsample its masks to model.inp_size."""
    low_res_masks, _ = decoder(
        image_embeddings=image_embedding,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )
    return model.postprocess_masks(low_res_masks, model.inp_size, model.inp_size)


def _predict_instance_colormap(image):
    """Segment ``image`` with the instance-category decoder and colorize it.

    Returns the first element of ``Label2Color`` output for the per-pixel
    argmax labels (presumably an (H, W, 3) color image - confirm in visual_utils).
    """
    model = _load_model(_OPTICAL_CONFIG, _OPTICAL_CHECKPOINT, False)
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_double'))
    input_img = _to_model_input(image)
    with torch.no_grad():
        image_embedding, sparse, dense = _encode(model, input_img)
        pred = _decode(model.mask_decoder, model, image_embedding, sparse, dense,
                       multimask_output=False)
    _, prediction = pred.max(dim=1)
    return label2color(prediction.cpu().numpy().astype(np.uint8))[0]


def Decoder1_optical_instance(image_input):
    """Gradio handler: instance-category segmentation of an optical image."""
    return _predict_instance_colormap(image_input)


def _merge_terrain_and_instance(prob_terrain, prob_instance):
    """Fuse terrain- and instance-decoder probabilities into unified labels.

    The terrain background (_TERRAIN_BACKGROUND) is mapped to 0, and the two
    argmax label maps are summed. Pixels whose sum is in _CONFLICT_FREE_SUMS
    are labelled directly by that sum. Per the original author's comment, these
    are the positions without conflict.

    All other pixels take the argmax over a 15-channel logit stack. That stack
    holds terrain channels _TERRAIN_CHANNELS, instance channels
    _INSTANCE_CHANNELS, and zeros elsewhere. Returns an int64 label tensor of
    shape (B, H, W).
    """
    _, prediction = prob_terrain.max(dim=1)
    prediction[prediction == _TERRAIN_BACKGROUND] = 0
    _, prediction_instance = prob_instance.max(dim=1)
    prediction_sum = prediction + prediction_instance

    decided = torch.zeros_like(prediction_sum, dtype=torch.bool)
    for value in _CONFLICT_FREE_SUMS:
        decided |= prediction_sum == value

    batch, _, height, width = prob_terrain.shape
    logits = prob_terrain.new_zeros((batch, _UNIFIED_NUM_CLASSES, height, width))
    for channel in _TERRAIN_CHANNELS:
        logits[:, channel] = prob_terrain[:, channel]
    for channel in _INSTANCE_CHANNELS:
        logits[:, channel] = prob_instance[:, channel]
    fallback = torch.argmax(logits, dim=1)
    return torch.where(decided, prediction_sum, fallback)


def Decoder1_optical_terrain(image_input):
    """Gradio handler: terrain segmentation fused with instance predictions."""
    model = _load_model(_OPTICAL_CONFIG, _OPTICAL_CHECKPOINT, False)
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_Vai'))
    input_img = _to_model_input(image_input)
    with torch.no_grad():
        image_embedding, sparse, dense = _encode(model, input_img)
        # Instance-category decoder.
        pred_instance = _decode(model.mask_decoder, model, image_embedding, sparse, dense,
                                multimask_output=False)
        # Terrain-category decoder.
        pred_terrain = _decode(model.mask_decoder_diwu, model, image_embedding, sparse, dense,
                               multimask_output=True)
        prediction_final = _merge_terrain_and_instance(
            torch.softmax(pred_terrain, dim=1), torch.softmax(pred_instance, dim=1))
    return label2color(prediction_final.cpu().numpy())[0]


def _find_boxes(mask_image, size):
    """Return inclusive (top, left, bottom, right) boxes of the drawn strokes.

    Non-zero pixels of ``mask_image`` are grouped into 8-connected components,
    after resizing to ``size`` = (height, width) with nearest-neighbour
    interpolation. Each component yields its bounding box.
    """
    if mask_image is None:
        return []
    mask = np.array(mask_image)
    if mask.ndim == 3:
        mask = mask.sum(-1)
    mask = np.uint8(mask != 0)
    height, width = size
    if mask.shape != (height, width):
        # The image was resized for the model; keep the strokes aligned with it.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    boxes = []
    for label in range(1, num_labels):  # label 0 is the background
        top = int(stats[label, cv2.CC_STAT_TOP])
        left = int(stats[label, cv2.CC_STAT_LEFT])
        bottom = top + int(stats[label, cv2.CC_STAT_HEIGHT]) - 1
        right = left + int(stats[label, cv2.CC_STAT_WIDTH]) - 1
        boxes.append((top, left, bottom, right))
    return boxes


def _draw_box_outline(canvas, box, border=_BOX_BORDER, color=_BOX_COLOR):
    """Paint, in place, the ``border``-pixel-wide outline of ``box`` on ``canvas``."""
    top, left, bottom, right = box
    region = canvas[top:bottom + 1, left:right + 1]  # view: writes reach canvas
    outline = np.ones(region.shape[:2], dtype=bool)
    outline[border:region.shape[0] - border, border:region.shape[1] - border] = False
    region[outline] = color


def Multi_box_prompts(input_prompt):
    """Gradio handler: show the instance segmentation only inside drawn boxes.

    ``input_prompt`` is the ImageMask dict with "image" and "mask" entries.
    Each drawn stroke becomes its bounding box, which is outlined in red.
    Everything outside the union of boxes is zeroed. If nothing was drawn,
    the full segmentation is returned.
    """
    prediction_to_save = _predict_instance_colormap(input_prompt["image"])
    height, width = prediction_to_save.shape[:2]
    boxes = _find_boxes(input_prompt.get("mask"), (height, width))
    if not boxes:
        return prediction_to_save

    inside_any_box = np.zeros((height, width), dtype=bool)
    for box in boxes:
        top, left, bottom, right = box
        inside_any_box[top:bottom + 1, left:right + 1] = True
        _draw_box_outline(prediction_to_save, box)
    # prediction_to_save is (H, W, C) here, as required by the RGB outline write.
    return prediction_to_save * inside_any_box[..., None]


def Decoder2_SAR(SAR_image, SAR_prompt):
    """Gradio handler: SAR terrain segmentation guided by a scatter prompt.

    ``SAR_prompt`` is a file path, as provided by the gr.Image(type='filepath')
    input. Raises ValueError when it is missing or unreadable.
    """
    if SAR_prompt is None:
        raise ValueError("Please provide a SAR polarization scatter prompt image.")
    scatter_prompt = cv2.imread(SAR_prompt, cv2.IMREAD_UNCHANGED)
    if scatter_prompt is None:
        raise ValueError(f"Could not read SAR scatter prompt image: {SAR_prompt}")

    model = _load_model(_SAR_CONFIG, _SAR_CHECKPOINT, True)
    label2color = visual_utils.Label2Color(cmap=visual_utils.color_map('Unify_YIJISAR'))
    input_img = _to_model_input(SAR_image)
    flip_flag = torch.Tensor([False])
    with torch.no_grad():
        scatter_torch = pre_scatter_prompt(scatter_prompt, flip_flag, device=input_img.device)
        scatter_torch = torch.nn.functional.interpolate(scatter_torch.unsqueeze(0), size=(256, 256))
        image_embedding, sparse, dense = _encode(model, input_img, scatter=scatter_torch)
        # Terrain-category decoder.
        pred = _decode(model.mask_decoder_diwu, model, image_embedding, sparse, dense,
                       multimask_output=True)
    _, prediction = pred.max(dim=1)
    return label2color(prediction.cpu().numpy())[0]


examples1_instance = [
    ['./images/optical/isaid/_P0007_1065_319_image.png'],
    ['./images/optical/isaid/_P0466_1068_420_image.png'],
    ['./images/optical/isaid/_P0897_146_34_image.png'],
    ['./images/optical/isaid/_P1397_844_904_image.png'],
    ['./images/optical/isaid/_P2645_883_965_image.png'],
    ['./images/optical/isaid/_P1398_1290_630_image.png'],
]
examples1_terrain = [
    ['./images/optical/vaihingen/top_mosaic_09cm_area2_105_image.png'],
    ['./images/optical/vaihingen/top_mosaic_09cm_area4_227_image.png'],
    ['./images/optical/vaihingen/top_mosaic_09cm_area20_142_image.png'],
    ['./images/optical/vaihingen/top_mosaic_09cm_area24_128_image.png'],
    ['./images/optical/vaihingen/top_mosaic_09cm_area27_34_image.png'],
]
examples1_multi_box = [
    ['./images/optical/isaid/_P0007_1065_319_image.png'],
    ['./images/optical/isaid/_P0466_1068_420_image.png'],
    ['./images/optical/isaid/_P0897_146_34_image.png'],
    ['./images/optical/isaid/_P1397_844_904_image.png'],
    ['./images/optical/isaid/_P2645_883_965_image.png'],
    ['./images/optical/isaid/_P1398_1290_630_image.png'],
]
examples2 = [
    ['./images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_4_image.png',
     './images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_4.png'],
    ['./images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_15_image.png',
     './images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_15.png'],
    ['./images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_24_image.png',
     './images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_24.png'],
    ['./images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_41_image.png',
     './images/sar/YIJISARGF3_MYN_QPSI_001269_E113.2_N23.0_20161105_L1A_L10002009158_ampl_41.png'],
    ['./images/sar/YIJISARGF3_MYN_QPSI_999996_E121.2_N30.3_20160815_L1A_L10002015572_ampl_150_image.png',
     './images/sar/YIJISARGF3_MYN_QPSI_999996_E121.2_N30.3_20160815_L1A_L10002015572_ampl_150.png'],
]

# RingMo-SAM designs two new promptable forms based on the characteristics of
# multimodal remote sensing images: the multi-boxes prompt and the SAR
# polarization scatter prompt.
title = (
    "RingMo-SAM:A Foundation Model for Segment Anything in Multimodal Remote Sensing Images "
    "[paper] "
    "RingMo-SAM can not only segment anything in optical and SAR remote sensing data, "
    "but also identify object categories."
)

_DECOUPLING_NOTE = (
    "Our decoder can decouple the SAM's mask decoder into instance category decoder "
    "and terrain category decoder to ensure that the model fits adequately to both types of data."
)
_WILD_EXAMPLES_NOTE = "Examples below were never trained and are randomly selected for testing in the wild."


def _description(*lines):
    """Join description lines into separate Markdown paragraphs."""
    return "\n\n".join(lines)


Decoder_optical_instance_io = gr.Interface(
    fn=Decoder1_optical_instance,
    inputs=[gr.Image(type='pil', label='optical_instance_img(光学图像)')],
    outputs=[gr.Image(label='segment_result', type='numpy')],
    description=_description(
        "Instance_Decoder:",
        "Instance-type objects (such as vehicle, aircraft, ship, etc.) have a smaller proportion.",
        _DECOUPLING_NOTE,
        "Choose an example below, or, upload optical instance images to be tested.",
        _WILD_EXAMPLES_NOTE,
    ),
    allow_flagging='auto',
    examples=examples1_instance,
    cache_examples=False,
)

Decoder_optical_terrain_io = gr.Interface(
    fn=Decoder1_optical_terrain,
    inputs=[gr.Image(type='pil', label='optical_terrain_img(光学图像)')],
    outputs=[gr.Image(label='segment_result', type='numpy')],
    description=_description(
        "Terrain_Decoder:",
        "Terrain-type objects (such as vegetation, land, river, etc.) have a larger proportion.",
        _DECOUPLING_NOTE,
        "Choose an example below, or, upload optical terrain images to be tested.",
        _WILD_EXAMPLES_NOTE,
    ),
    allow_flagging='auto',
    examples=examples1_terrain,
    cache_examples=False,
)

Decoder_multi_box_prompts_io = gr.Interface(
    fn=Multi_box_prompts,
    inputs=[gr.ImageMask(brush_radius=4, type='pil', label='input_img(图像)')],
    outputs=[gr.Image(label='segment_result', type='numpy')],
    description=_description(
        "Multi-box Prompts:",
        "Multiple boxes are sequentially encoded as concated sparse high-dimensional feature embedding, "
        "the corresponding multiple high-dimensional features are concated together into a "
        "high-dimensional feature vector as part of the sparse embedding.",
        "Choose an example below, or, upload images to be tested, and then draw multi-boxes.",
        _WILD_EXAMPLES_NOTE,
    ),
    allow_flagging='auto',
    examples=examples1_multi_box,
    cache_examples=False,
)

Decoder_SAR_io = gr.Interface(
    fn=Decoder2_SAR,
    inputs=[gr.Image(type='pil', label='SAR_img(SAR图像)'),
            gr.Image(type='filepath', label='SAR_prompt(偏振散射提示)')],
    outputs=[gr.Image(label='segment_result', type='numpy')],
    description=_description(
        "SAR Polarization Scatter Prompts:",
        "Different terrain categories usually exhibit different scattering properties. "
        "Therefore, we code network for coded mapping of these SAR polarization scatter prompts "
        "to the corresponding SAR images, which improves the segmentation results of SAR images.",
        "Choose an example below, or, upload SAR images and the corresponding polarization "
        "scatter prompts to be tested.",
        _WILD_EXAMPLES_NOTE,
    ),
    allow_flagging='auto',
    examples=examples2,
    cache_examples=False,
)

# Keep `demo` bound to the interface itself (launch() returns something else).
demo = gr.TabbedInterface(
    [Decoder_optical_instance_io, Decoder_optical_terrain_io,
     Decoder_multi_box_prompts_io, Decoder_SAR_io],
    ['optical_instance_img(光学图像)', 'optical_terrain_img(光学图像)',
     'multi_box_prompts(多框提示)', 'SAR_img(偏振散射提示)'],
    title=title,
)
demo.launch()