Spaces:
Running
Running
# coding:utf-8 | |
import os | |
import numpy as np | |
import cv2 | |
from typing import Optional | |
import torch | |
# from models.transforms import ResizeLongestSide | |
# from .transforms import ResizeLongestSide | |
from torchvision import transforms | |
def get_prompt_inp_scatter(scatter_file_):
    """Read a scatter-prompt image from disk with ``cv2.IMREAD_UNCHANGED``.

    Args:
        scatter_file_: Path to the image file.

    Returns:
        The image array exactly as returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise surface later as an unrelated error.
    """
    scatter_mask = cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
    if scatter_mask is None:
        raise FileNotFoundError(f'Could not read scatter prompt image: {scatter_file_}')
    return scatter_mask
# Built once at import time instead of on every call; the pipeline is constant.
_SCATTER_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def pre_scatter_prompt(scatter, filp, device):
    """Convert a scatter-prompt image into a normalized tensor on ``device``.

    Args:
        scatter: Image array accepted by ``cv2.flip`` and ``ToTensor``.
            NOTE(review): ``Normalize`` uses three-channel mean/std, so this
            presumably must be a 3-channel image — confirm. cv2 loads
            channels as BGR, while these mean/std values are commonly quoted
            in RGB order — confirm the intended channel order.
        filp: If truthy, flip the image horizontally (``cv2.flip(..., 1)``)
            before converting. Truthiness matches how ``get_prompt_inp``
            treats its ``filp`` flag.
        device: Target device for the returned tensor.

    Returns:
        The ``ToTensor`` + ``Normalize`` result moved to ``device``.
    """
    if filp:
        scatter = cv2.flip(scatter, 1)
    scatter_torch = _SCATTER_TRANSFORM(scatter)
    return scatter_torch.to(device)
def get_prompt_inp(txt_file_, filp=False, img_width=1024.0):
    """Parse a quadrilateral label file into point and box prompts.

    Each non-blank line must hold ten whitespace-separated fields:
    ``x1 y1 x2 y2 x3 y3 x4 y4 classname <extra>``. The four corners are
    reduced to their axis-aligned bounding box. The tenth field is ignored.

    Args:
        txt_file_: Path to the label text file.
        filp: If truthy, mirror x coordinates horizontally about
            ``img_width``.
        img_width: Width used for the horizontal mirror (``x -> img_width - x``).
            Defaults to 1024.0, the value previously hard-coded.

    Returns:
        ``(points, labels, boxes, None)``:
        ``points`` is a list of ``[x_center, y_center]``,
        ``labels`` is a list of class-name strings, and
        ``boxes`` is a list of ``[[xmin, ymin], [xmax, ymax]]``.
        The fourth element (masks) is always ``None``.

    Raises:
        ValueError: if a non-blank line does not have exactly ten fields,
            or a coordinate cannot be parsed as a float.
    """
    points = []
    labels = []
    boxes = []
    # ``with`` guarantees the file is closed (the original leaked the handle).
    with open(txt_file_, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Skip blank lines, e.g. a trailing empty line at end of file.
                continue
            if len(fields) != 10:
                raise ValueError(
                    f'{txt_file_}:{line_no}: expected 10 fields, got {len(fields)}'
                )
            coords = [float(v) for v in fields[:8]]
            classname = fields[8]
            xs = coords[0::2]
            ys = coords[1::2]
            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)
            if filp:
                # Mirroring swaps left and right: the new left edge comes from
                # the old right edge. This keeps xmin <= xmax, so box corners
                # stay top-left / bottom-right.
                xmin, xmax = img_width - xmax, img_width - xmin
            points.append([(xmin + xmax) / 2, (ymin + ymax) / 2])
            labels.append(classname)
            boxes.append([[xmin, ymin], [xmax, ymax]])
    return points, labels, boxes, None
def pre_prompt(points=None, boxes=None, masks=None, device=None, scale=16.0):
    """Convert point/box/mask prompts to float tensors on ``device``.

    Points and boxes are divided by ``scale``. The default of 16.0 is the
    value previously hard-coded; NOTE(review): presumably the input-to-feature
    resolution ratio — confirm. Masks are converted but not rescaled. Any
    argument passed as ``None`` is returned as ``None``.

    ``is not None`` is used instead of ``!= None`` because the latter is
    element-wise for numpy arrays and made ``if`` raise on array inputs.

    Returns:
        ``(points_torch, boxes_torch, masks_torch)``.
    """
    points_torch = None
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device) / scale
    boxes_torch = None
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device) / scale
    masks_torch = None
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)
    return points_torch, boxes_torch, masks_torch
# def pre_prompt( | |
# point_coords: Optional[np.ndarray] = None, | |
# point_labels: Optional[np.ndarray] = None, | |
# box: Optional[np.ndarray] = None, | |
# mask_input: Optional[np.ndarray] = None, | |
# device=None, | |
# original_size = [1024, 1024] | |
# ): | |
# | |
# transform = ResizeLongestSide(1024) | |
# # Transform input prompts | |
# coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None | |
# if point_coords is not None: | |
# assert ( | |
# point_labels is not None | |
# ), "point_labels must be supplied if point_coords is supplied." | |
# point_coords = transform.apply_coords(point_coords, original_size) | |
# coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device) | |
# labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device) | |
# coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] | |
# if box is not None: | |
# box = transform.apply_boxes(box, original_size) | |
# box_torch = torch.as_tensor(box, dtype=torch.float, device=device) | |
# box_torch = box_torch[None, :] | |
# if mask_input is not None: | |
# mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device) | |
# mask_input_torch = mask_input_torch[None, :, :, :] | |
# | |
# return coords_torch, labels_torch, box_torch, mask_input_torch | |
if __name__ == '__main__':
    # Manual smoke check: parse the first prompt file (in sorted order) and
    # show its boxes after the same /16.0 division that pre_prompt applies.
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    txt_list = sorted(os.listdir(txt_dir))
    if not txt_list:
        raise SystemExit(f'No prompt files found in {txt_dir}')
    txt_file_0 = os.path.join(txt_dir, txt_list[0])
    # ``filp`` is passed explicitly; the original call omitted it.
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)
    boxes_torch = torch.as_tensor(boxes, dtype=torch.float)
    boxes_torch = boxes_torch / 16.0
    print(boxes_torch, boxes_torch.shape)