RingMo-SAM / models /utils_prompt.py
AI-Cyber's picture
Upload 123 files
8d7921b
raw
history blame contribute delete
4.64 kB
# coding:utf-8
import os
import numpy as np
import cv2
from typing import Optional
import torch
# from models.transforms import ResizeLongestSide
# from .transforms import ResizeLongestSide
from torchvision import transforms
def get_prompt_inp_scatter(scatter_file_):
    """Load the scatter prompt image stored at ``scatter_file_``.

    The file is decoded with ``cv2.IMREAD_UNCHANGED`` and the value produced
    by ``cv2.imread`` is returned without any further processing.

    NOTE(review): per the OpenCV documentation ``cv2.imread`` signals an
    unreadable path by returning ``None`` instead of raising; callers are
    expected to handle that themselves -- confirm at the call sites.
    """
    return cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
def pre_scatter_prompt(scatter, filp, device):
    """Convert a scatter prompt image into a normalized tensor on ``device``.

    Args:
        scatter: Image array accepted by ``cv2.flip`` and
            ``transforms.ToTensor`` (the three-value mean/std below suggest
            a 3-channel image -- TODO confirm against the data).
        filp: Flip switch (spelling kept for caller compatibility). Only a
            value that compares equal to ``True`` triggers
            ``cv2.flip(scatter, 1)``.
        device: Target passed to ``Tensor.to``.

    Returns:
        The tensor produced by ``ToTensor`` followed by ``Normalize``, moved
        to ``device``.
    """
    # Equality (not truthiness) is kept on purpose so that exactly the same
    # inputs trigger the flip as before.
    image = cv2.flip(scatter, 1) if filp == True else scatter  # noqa: E712

    to_tensor = transforms.ToTensor()
    # Presumably the usual ImageNet channel statistics -- verify if the
    # backbone was pretrained with different ones.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    return normalize(to_tensor(image)).to(device)
def get_prompt_inp(txt_file_, filp=False, img_width=1024.0):
    """Parse a label text file into point, label and box prompts.

    Every non-blank line must contain ten whitespace-separated fields: the
    eight coordinates ``x1 y1 x2 y2 x3 y3 x4 y4`` of a quadrilateral, a class
    name, and one trailing field that is ignored. Each quadrilateral is
    reduced to its axis-aligned bounding box.

    Args:
        txt_file_: Path of the label file to read.
        filp: When truthy, mirror the x coordinates around ``img_width`` so
            the prompts match a horizontally flipped image (spelling kept
            for caller compatibility). Defaults to ``False``.
        img_width: Width used for the mirroring; defaults to the previously
            hard-coded ``1024.0``.

    Returns:
        A tuple ``(points, labels, boxes, None)`` where, per object,
        ``points`` holds ``[x_center, y_center]``, ``labels`` holds the class
        name string and ``boxes`` holds ``[[xmin, ymin], [xmax, ymax]]``.
        The fourth element (masks) is always ``None``.

    Raises:
        ValueError: If a non-blank line does not have exactly ten fields or
            a coordinate cannot be parsed as a float.
    """
    points = []
    labels = []
    boxes = []

    # The context manager guarantees the handle is closed, also on errors.
    with open(txt_file_) as f:
        for line in f:
            # split() without argument tolerates repeated/trailing blanks and
            # drops the newline that used to stick to the last field.
            fields = line.split()
            if not fields:
                continue  # blank line (e.g. trailing newline at end of file)

            x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4, classname, _ = fields
            xs = (float(x_1), float(x_2), float(x_3), float(x_4))
            ys = (float(y_1), float(y_2), float(y_3), float(y_4))

            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)

            if filp:
                # Mirroring swaps the roles of the two edges: the old right
                # edge becomes the new left edge. Assigning them crosswise
                # keeps xmin <= xmax so the box stays [[left, top],
                # [right, bottom]].
                xmin, xmax = img_width - xmax, img_width - xmin

            points.append([(xmin + xmax) / 2, (ymin + ymax) / 2])
            labels.append(classname)
            boxes.append([[xmin, ymin], [xmax, ymax]])

    return points, labels, boxes, None
def pre_prompt(points=None, boxes=None, masks=None, device=None, downscale=16.0):
    """Convert prompt data to float tensors, rescaling coordinates.

    Args:
        points: Point coordinates (anything ``torch.as_tensor`` accepts, e.g.
            nested lists or a NumPy array) or ``None``.
        boxes: Box coordinates in the same kind of container, or ``None``.
        masks: Mask data, or ``None``. Converted to a tensor but not rescaled.
        device: Device forwarded to ``torch.as_tensor``.
        downscale: Divisor applied to point and box coordinates; defaults to
            the previously hard-coded ``16.0`` (presumably the stride between
            image pixels and the feature map -- TODO confirm).

    Returns:
        A tuple ``(points_torch, boxes_torch, masks_torch)``. Each element is
        a ``torch.float`` tensor, or ``None`` when the matching argument was
        ``None``.
    """
    # Identity checks are required here: ``array != None`` is evaluated
    # elementwise for NumPy arrays and makes the ``if`` raise ValueError.
    points_torch = None
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device)
        # Out-of-place division: as_tensor may share memory with the input.
        points_torch = points_torch / downscale

    boxes_torch = None
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device)
        boxes_torch = boxes_torch / downscale

    masks_torch = None
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)

    return points_torch, boxes_torch, masks_torch
# def pre_prompt(
# point_coords: Optional[np.ndarray] = None,
# point_labels: Optional[np.ndarray] = None,
# box: Optional[np.ndarray] = None,
# mask_input: Optional[np.ndarray] = None,
# device=None,
# original_size = [1024, 1024]
# ):
#
# transform = ResizeLongestSide(1024)
# # Transform input prompts
# coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
# if point_coords is not None:
# assert (
# point_labels is not None
# ), "point_labels must be supplied if point_coords is supplied."
# point_coords = transform.apply_coords(point_coords, original_size)
# coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
# labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
# coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
# if box is not None:
# box = transform.apply_boxes(box, original_size)
# box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
# box_torch = box_torch[None, :]
# if mask_input is not None:
# mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device)
# mask_input_torch = mask_input_torch[None, :, :, :]
#
# return coords_torch, labels_torch, box_torch, mask_input_torch
if __name__ == '__main__':
    # Manual smoke test: parse the first label file of the training prompt
    # directory and print the resulting prompts and the rescaled boxes.
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    # os.listdir returns entries in arbitrary order; sort so the same file is
    # picked on every run.
    txt_list = sorted(os.listdir(txt_dir))
    if not txt_list:
        raise SystemExit('no label files found in ' + txt_dir)
    txt_file_0 = os.path.join(txt_dir, txt_list[0])
    # get_prompt_inp needs the flip flag; omitting it raised a TypeError.
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)
    boxes_torch = torch.as_tensor(boxes, dtype=torch.float)
    boxes_torch = boxes_torch / 16.0
    print(boxes_torch, boxes_torch.shape)