Vrk commited on
Commit
f95c696
1 Parent(s): 89d03ac
Files changed (1) hide show
  1. Dataset.py +61 -0
Dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader, Dataset
4
+ import tensorflow as tf
5
+ import numpy as np
6
+
7
+ def pad_sequences(sequences, max_seq_len=0):
8
+ """Pad sequences to max length in sequence."""
9
+ max_seq_len = max(max_seq_len, max(len(sequence) for sequence in sequences))
10
+ padded_sequences = np.zeros((len(sequences), max_seq_len))
11
+ for i, sequence in enumerate(sequences):
12
+ padded_sequences[i][:len(sequence)] = sequence
13
+ return padded_sequences
14
+
15
+ class SkimlitDataset(Dataset):
16
+ def __init__(self, text_seq, line_num, total_line):
17
+ self.text_seq = text_seq
18
+ self.line_num_one_hot = line_num
19
+ self.total_line_one_hot = total_line
20
+
21
+ def __len__(self):
22
+ return len(self.text_seq)
23
+
24
+ def __str__(self):
25
+ return f"<Dataset(N={len(self)})>"
26
+
27
+ def __getitem__(self, index):
28
+ X = self.text_seq[index]
29
+ line_num = self.line_num_one_hot[index]
30
+ total_line = self.total_line_one_hot[index]
31
+ return [X, len(X), line_num, total_line]
32
+
33
+ def collate_fn(self, batch):
34
+ """Processing on a batch"""
35
+ # Getting Input
36
+ batch = np.array(batch)
37
+ text_seq = batch[:,0]
38
+ seq_lens = batch[:, 1]
39
+ line_nums = batch[:, 2]
40
+ total_lines = batch[:, 3]
41
+
42
+ # padding inputs
43
+ pad_text_seq = pad_sequences(sequences=text_seq) # max_seq_len=max_length
44
+
45
+ # converting line nums into one-hot encoding
46
+ line_nums = tf.one_hot(line_nums, depth=20)
47
+
48
+ # converting total lines into one-hot encoding
49
+ total_lines = tf.one_hot(total_lines, depth=24)
50
+
51
+ # converting inputs to tensors
52
+ pad_text_seq = torch.LongTensor(pad_text_seq.astype(np.int32))
53
+ seq_lens = torch.LongTensor(seq_lens.astype(np.int32))
54
+ line_nums = torch.tensor(line_nums.numpy())
55
+ total_lines = torch.tensor(total_lines.numpy())
56
+
57
+ return pad_text_seq, seq_lens, line_nums, total_lines
58
+
59
+ def create_dataloader(self, batch_size, shuffle=False, drop_last=False):
60
+ dataloader = DataLoader(dataset=self, batch_size=batch_size, collate_fn=self.collate_fn, shuffle=shuffle, drop_last=drop_last, pin_memory=True)
61
+ return dataloader