pt-sk commited on
Commit
635d0d0
1 Parent(s): c30373a

Upload dataset.py

Browse files
Files changed (1) hide show
  1. Others/dataset.py +90 -0
Others/dataset.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset
4
+
5
+ class BilingualDataset(Dataset):
6
+
7
+ def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
8
+ super().__init__()
9
+ self.seq_len = seq_len
10
+
11
+ self.ds = ds
12
+ self.tokenizer_src = tokenizer_src
13
+ self.tokenizer_tgt = tokenizer_tgt
14
+ self.src_lang = src_lang
15
+ self.tgt_lang = tgt_lang
16
+
17
+ self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
18
+ self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
19
+ self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)
20
+
21
+ def __len__(self):
22
+ return len(self.ds)
23
+
24
+ def __getitem__(self, idx):
25
+ src_target_pair = self.ds[idx]
26
+ src_text = src_target_pair['translation'][self.src_lang]
27
+ tgt_text = src_target_pair['translation'][self.tgt_lang]
28
+
29
+ # Transform the text into tokens
30
+ enc_input_tokens = self.tokenizer_src.encode(src_text).ids
31
+ dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
32
+
33
+ # Add sos, eos and padding to each sentence
34
+ enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # We will add <s> and </s>
35
+ # We will only add <s>, and </s> only on the label
36
+ dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
37
+
38
+ # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
39
+ if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
40
+ raise ValueError("Sentence is too long")
41
+
42
+ # Add <s> and </s> token
43
+ encoder_input = torch.cat(
44
+ [
45
+ self.sos_token,
46
+ torch.tensor(enc_input_tokens, dtype=torch.int64),
47
+ self.eos_token,
48
+ torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
49
+ ],
50
+ dim=0,
51
+ )
52
+
53
+ # Add only <s> token
54
+ decoder_input = torch.cat(
55
+ [
56
+ self.sos_token,
57
+ torch.tensor(dec_input_tokens, dtype=torch.int64),
58
+ torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
59
+ ],
60
+ dim=0,
61
+ )
62
+
63
+ # Add only </s> token
64
+ label = torch.cat(
65
+ [
66
+ torch.tensor(dec_input_tokens, dtype=torch.int64),
67
+ self.eos_token,
68
+ torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
69
+ ],
70
+ dim=0,
71
+ )
72
+
73
+ # Double check the size of the tensors to make sure they are all seq_len long
74
+ assert encoder_input.size(0) == self.seq_len
75
+ assert decoder_input.size(0) == self.seq_len
76
+ assert label.size(0) == self.seq_len
77
+
78
+ return {
79
+ "encoder_input": encoder_input, # (seq_len)
80
+ "decoder_input": decoder_input, # (seq_len)
81
+ "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
82
+ "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
83
+ "label": label, # (seq_len)
84
+ "src_text": src_text,
85
+ "tgt_text": tgt_text,
86
+ }
87
+
88
+ def causal_mask(size):
89
+ mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
90
+ return mask == 0