Seth0330 commited on
Commit
d2a96ab
·
verified ·
1 Parent(s): f54b486

Create mydatasets.py

Browse files
Files changed (1) hide show
  1. pdrt/mydatasets.py +128 -0
pdrt/mydatasets.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import Any
5
+ from ast import literal_eval
6
+ from torch.utils.data import Dataset
7
+
8
+ import paths
9
+ from utils_ctc import sample_text_to_seq
10
+
11
+ ######################################################
12
+ # Dataset Swin + CTC
13
+ ######################################################
14
+
15
+ class myDatasetCTC(Dataset):
16
+
17
+ def __init__(self, partition = "train"):
18
+
19
+ self.processor = None
20
+ self.partition = partition
21
+
22
+ self.path_labels = paths.IMAGE_PATH
23
+ self.path_images = paths.GT_PATH
24
+ self.image_name_list = []
25
+ self.label_list = []
26
+
27
+ f = open(self.path_labels, 'r')
28
+ Lines = f.readlines()
29
+
30
+ for line in Lines:
31
+ line = line.strip().split()
32
+ self.image_name_list.append(self.path_images + line[0])
33
+ self.label_list.append(' '.join(line[1:]))
34
+
35
+ print("\tSamples Loaded: ", len(self.label_list), "\n-------------------------------------")
36
+
37
+ def set_processor(self, processor):
38
+ self.processor = processor
39
+
40
+ def __len__(self):
41
+ return len(self.image_name_list)
42
+
43
+ def __getitem__(self, idx):
44
+
45
+ with Image.open(self.image_name_list[idx]) as image:
46
+ image = image.convert("RGB")
47
+ image_tensor = np.array(image)
48
+ label = self.label_list[idx]
49
+
50
+ image_tensor = self.processor(
51
+ image_tensor,
52
+ random_padding=self.partitions == "train",
53
+ return_tensors="pt"
54
+ ).pixel_values
55
+ image_tensor = image_tensor.squeeze()
56
+
57
+ # ctc
58
+ label_tensor = torch.tensor(sample_text_to_seq(label, self.text_to_seq))
59
+
60
+ return {"idx": idx, "img": image_tensor, "label": label_tensor, "raw_label": label}
61
+
62
+
63
+ ######################################################
64
+ # Dataset Vision Encoder-Decoder (VED)
65
+ ######################################################
66
+
67
+ class myDatasetTransformerDecoder(Dataset):
68
+ def __init__(self, partition="train"):
69
+
70
+ self.max_length = paths.MAX_LENGTH
71
+ self.partition = partition
72
+ self.processor = None
73
+ self.ignore_id = -100
74
+
75
+ self.path_img = paths.IMAGE_PATH
76
+ self.path_transcriptions = paths.GT_PATH
77
+ self.image_name_list = []
78
+ self.label_list = []
79
+
80
+ template = '{"gt_parse": {"text_sequence" : '
81
+ with open(self.path_transcriptions, 'r') as file:
82
+ for line in file:
83
+ line = line.strip().split()
84
+
85
+ image_name = line[0]
86
+ label_gt = ' '.join(line[1:])
87
+ label_gt = template + '"' + label_gt + '"' + "}}"
88
+
89
+ self.image_name_list.append(self.path_img + image_name)
90
+ self.label_list.append(label_gt)
91
+
92
+ print("\tSamples Loaded: ", len(self.label_list))
93
+
94
+ def dict2token(self, obj: Any):
95
+ return obj["text_sequence"]
96
+
97
+ def set_processor(self, processor):
98
+ self.processor = processor
99
+
100
+ def __len__(self):
101
+ return len(self.image_name_list)
102
+
103
+ def __getitem__(self, idx):
104
+
105
+ image = Image.open(self.image_name_list[idx]).convert("RGB")
106
+ image_tensor = np.array(image)
107
+
108
+ pixel_values: torch.Tensor = self.processor(image_tensor, random_padding=self.partition == "train", return_tensors="pt").pixel_values[0]
109
+
110
+ label = self.label_list[idx]
111
+ label = literal_eval(label)
112
+ assert "gt_parse" in label and isinstance(label["gt_parse"], dict)
113
+ gt_dicts = [label["gt_parse"]]
114
+ target_sequence=[self.dict2token(gt_dict) + self.processor.tokenizer.eos_token for gt_dict in gt_dicts]
115
+
116
+ input_ids = self.processor.tokenizer(
117
+ target_sequence,
118
+ add_special_tokens=False,
119
+ max_length=self.max_length,
120
+ padding="max_length",
121
+ truncation=True,
122
+ return_tensors="pt",
123
+ )["input_ids"].squeeze(0)
124
+
125
+ labels = input_ids.clone()
126
+ labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token
127
+
128
+ return {"idx": idx, "img": pixel_values, "label": labels, "raw_label": target_sequence}