Spaces:
Running
Running
File size: 4,643 Bytes
8d7921b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# coding:utf-8
import os
import numpy as np
import cv2
from typing import Optional
import torch
# from models.transforms import ResizeLongestSide
# from .transforms import ResizeLongestSide
from torchvision import transforms
def get_prompt_inp_scatter(scatter_file_):
    """Load a scatter-prompt image from disk exactly as stored.

    The image is read with ``cv2.IMREAD_UNCHANGED``, so OpenCV performs no
    colour or bit-depth conversion and any alpha channel is kept.

    Args:
        scatter_file_: Path of the image file to read.

    Returns:
        Whatever ``cv2.imread`` returns. Note that OpenCV returns ``None``
        (rather than raising) when the file is missing or cannot be decoded.
    """
    return cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
def pre_scatter_prompt(scatter, filp, device):
    """Turn a scatter-prompt image into a normalized tensor on ``device``.

    Args:
        scatter: Image array, e.g. as returned by ``get_prompt_inp_scatter``.
            NOTE(review): ``Normalize`` below uses 3-element mean/std, so a
            3-channel HxWxC image appears to be assumed — confirm.
        filp: If truthy, mirror the image horizontally (``cv2.flip`` with
            flipCode=1) before conversion. ("filp" is the existing spelling
            of "flip" and is kept so keyword callers keep working.)
        device: Torch device the resulting tensor is moved to.

    Returns:
        The image after ``transforms.ToTensor()`` and ``transforms.Normalize``
        with the ImageNet mean/std, moved to ``device``.
    """
    # Plain truthiness test, matching how get_prompt_inp handles the same flag
    # (previously ``filp == True``).
    if filp:
        scatter = cv2.flip(scatter, 1)
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return img_transform(scatter).to(device)
def get_prompt_inp(txt_file_, filp=False, img_width=1024.0):
    """Parse a quadrilateral label file into point and box prompts.

    Each non-blank line must contain eight corner coordinates
    ``x1 y1 x2 y2 x3 y3 x4 y4``, then a class name, then optionally one more
    field. The extra field is ignored; in DOTA-style labels it is the
    difficulty flag, but confirm this for the dataset in use. Blank lines
    are skipped.

    Args:
        txt_file_: Path to the label text file.
        filp: If truthy, mirror every annotation horizontally about
            ``img_width`` (x -> img_width - x). ("filp" spelling kept for
            compatibility.) Defaults to ``False``.
        img_width: Width used for the horizontal mirror. Defaults to 1024.0,
            the value previously hard-coded here.

    Returns:
        A tuple ``(points, labels, boxes, None)``:
          * ``points``: list of ``[x_center, y_center]`` per object.
          * ``labels``: list of class-name strings.
          * ``boxes``: list of ``[[xmin, ymin], [xmax, ymax]]``. Each is the
            axis-aligned bounding box of the object's quadrilateral.
          * The fourth element (masks) is always ``None``.

    Raises:
        ValueError: If a non-blank line does not have 9 or 10 fields, or if
            a coordinate is not a number.
    """
    points = []
    labels = []
    boxes = []
    # ``with`` guarantees the file is closed (the old code never closed it).
    with open(txt_file_) as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no argument also drops the trailing newline and
            # tolerates repeated or trailing spaces.
            fields = line.split()
            if not fields:
                continue
            if len(fields) not in (9, 10):
                raise ValueError(
                    f"{txt_file_}:{line_no}: expected 9 or 10 fields, "
                    f"got {len(fields)}"
                )
            xs = [float(v) for v in fields[0:8:2]]
            ys = [float(v) for v in fields[1:8:2]]
            classname = fields[8]
            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)
            if filp:
                # Mirroring reverses x order: the old right edge becomes the
                # new left edge. Without this swap the stored corners would
                # be [[right, top], [left, bottom]].
                xmin, xmax = img_width - xmax, img_width - xmin
            points.append([(xmin + xmax) / 2, (ymin + ymax) / 2])
            labels.append(classname)
            boxes.append([[xmin, ymin], [xmax, ymax]])
    return points, labels, boxes, None
def pre_prompt(points=None, boxes=None, masks=None, device=None, scale=16.0):
    """Convert optional point/box/mask prompts into float32 tensors.

    Point and box coordinates are divided by ``scale``. NOTE(review): the
    default of 16.0 (previously hard-coded) presumably maps image-pixel
    coordinates onto a downsampled feature grid — confirm against the model.
    Masks are converted but not rescaled. Any prompt given as ``None`` is
    returned as ``None``.

    Args:
        points: Optional array-like of point coordinates.
        boxes: Optional array-like of box corner coordinates.
        masks: Optional array-like mask input.
        device: Torch device for the created tensors (``None`` = default).
        scale: Divisor applied to point and box coordinates.

    Returns:
        ``(points_torch, boxes_torch, masks_torch)``.
    """
    # ``is not None`` rather than ``!= None``: the latter is evaluated
    # elementwise for numpy arrays and makes ``if`` raise "truth value of an
    # array is ambiguous".
    points_torch = None
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device) / scale
    boxes_torch = None
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device) / scale
    masks_torch = None
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)
    return points_torch, boxes_torch, masks_torch
# def pre_prompt(
# point_coords: Optional[np.ndarray] = None,
# point_labels: Optional[np.ndarray] = None,
# box: Optional[np.ndarray] = None,
# mask_input: Optional[np.ndarray] = None,
# device=None,
# original_size = [1024, 1024]
# ):
#
# transform = ResizeLongestSide(1024)
# # Transform input prompts
# coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
# if point_coords is not None:
# assert (
# point_labels is not None
# ), "point_labels must be supplied if point_coords is supplied."
# point_coords = transform.apply_coords(point_coords, original_size)
# coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
# labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
# coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
# if box is not None:
# box = transform.apply_boxes(box, original_size)
# box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
# box_torch = box_torch[None, :]
# if mask_input is not None:
# mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device)
# mask_input_torch = mask_input_torch[None, :, :, :]
#
# return coords_torch, labels_torch, box_torch, mask_input_torch
if __name__ == '__main__':
    # Manual smoke check: parse one label file and print its prompts.
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    # os.listdir order is arbitrary; sort so repeated runs pick the same file.
    txt_list = sorted(os.listdir(txt_dir))
    if not txt_list:
        raise SystemExit(f"no label files found in {txt_dir}")
    txt_file_0 = os.path.join(txt_dir, txt_list[0])
    # Pass the flip flag explicitly (False = no horizontal mirroring).
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)
    # Same /16.0 coordinate scaling that pre_prompt applies to boxes.
    boxes_torch = torch.as_tensor(boxes, dtype=torch.float) / 16.0
    print(boxes_torch, boxes_torch.shape)
|