# coding:utf-8
import os

import numpy as np
import cv2
from typing import Optional
import torch

# from models.transforms import ResizeLongestSide
# from .transforms import ResizeLongestSide
from torchvision import transforms

def get_prompt_inp_scatter(scatter_file_):
    """Load a scatter-prompt mask from disk, keeping its original bit depth and channels."""
    scatter_mask = cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
    return scatter_mask

def pre_scatter_prompt(scatter, flip, device):
    """Optionally mirror the scatter mask horizontally, then convert it to a
    normalized tensor on the target device."""
    if flip:
        scatter = cv2.flip(scatter, 1)  # 1 = flip around the vertical axis

    # Standard ImageNet mean/std normalization.
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    scatter_torch = img_transform(scatter)
    scatter_torch = scatter_torch.to(device)
    return scatter_torch

def get_prompt_inp(txt_file_, flip=False):
    """Parse a label file (four box corners, class name, and a trailing field
    per line) into point, label, and box prompts."""
    with open(txt_file_) as f:
        lines = f.readlines()
    points = []
    labels = []
    boxes = []
    for line in lines:
        x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4, classname, _ = line.split(' ')
        x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4 = float(x_1), float(y_1), \
                                                 float(x_2), float(y_2), \
                                                 float(x_3), float(y_3), \
                                                 float(x_4), float(y_4)
        # Reduce the four corners to an axis-aligned bounding box.
        xmin = min(x_1, x_2, x_3, x_4)
        xmax = max(x_1, x_2, x_3, x_4)
        ymin = min(y_1, y_2, y_3, y_4)
        ymax = max(y_1, y_2, y_3, y_4)
        if flip:
            # Mirror horizontally in a 1024-wide image; the two x limits swap.
            xmin, xmax = 1024.0 - xmax, 1024.0 - xmin
        x_center = (xmin + xmax) / 2
        y_center = (ymin + ymax) / 2
        points.append([x_center, y_center])
        labels.append(classname)
        boxes.append([[xmin, ymin], [xmax, ymax]])
    # Mask prompts are not derived from the label file, so None is returned for them.
    return points, labels, boxes, None

def pre_prompt(points=None, boxes=None, masks=None, device=None):
    """Convert point/box/mask prompts to float tensors on the target device.

    Point and box coordinates are divided by 16, i.e. scaled to a
    1/16-resolution grid.
    """
    points_torch = points
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device)
        points_torch = points_torch / 16.0
    boxes_torch = boxes
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device)
        boxes_torch = boxes_torch / 16.0
    masks_torch = masks
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)

    return points_torch, boxes_torch, masks_torch



# def pre_prompt(
#         point_coords: Optional[np.ndarray] = None,
#         point_labels: Optional[np.ndarray] = None,
#         box: Optional[np.ndarray] = None,
#         mask_input: Optional[np.ndarray] = None,
#         device=None,
#         original_size = [1024, 1024]
#         ):
#
#     transform = ResizeLongestSide(1024)
#     # Transform input prompts
#     coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
#     if point_coords is not None:
#         assert (
#                 point_labels is not None
#         ), "point_labels must be supplied if point_coords is supplied."
#         point_coords = transform.apply_coords(point_coords, original_size)
#         coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
#         labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
#         coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
#     if box is not None:
#         box = transform.apply_boxes(box, original_size)
#         box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
#         box_torch = box_torch[None, :]
#     if mask_input is not None:
#         mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device)
#         mask_input_torch = mask_input_torch[None, :, :, :]
#
#     return coords_torch, labels_torch, box_torch, mask_input_torch

if __name__ == '__main__':
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    txt_list = os.listdir(txt_dir)
    txt_file_0 = os.path.join(txt_dir, txt_list[0])
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)
    boxes_torch = torch.as_tensor(boxes, dtype=torch.float)
    boxes_torch = boxes_torch / 16.0
    print(boxes_torch, boxes_torch.shape)
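
    # Sketch: the same /16 scaling can also be done through pre_prompt, which
    # converts the centre points as well; 'cpu' is an arbitrary device chosen
    # purely for this illustration.
    points_t, boxes_t, _ = pre_prompt(points=points, boxes=boxes, masks=None,
                                      device='cpu')
    print(points_t, points_t.shape)
    print(boxes_t, boxes_t.shape)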