GPT007 committed on
Commit
6100a96
1 Parent(s): f0d1b23

Update PrateritumGPT.py


Changed some values and added <BOS> token
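For context, the <BOS> change below reserves one extra id past the character vocabulary for a begin-of-sequence marker and prepends it to every decoder target, which is why vocab_size grows from len(tokens)+1 to len(tokens)+2. A minimal sketch of that pattern, assuming a character-level token list like the one the script builds (the exact inventory is not shown in this diff, so the tokens list here is illustrative):

# Illustrative reconstruction, not the file's own code: the real token
# list lives in PrateritumGPT.py and is not part of this diff.
tokens = list("abcdefghijklmnopqrstuvwxyzäöüß ")  # assumed character vocabulary

PAD = len(tokens)      # padding id, also the stop id when decoding
BOS = len(tokens) + 1  # new <BOS> id, hence vocab_size=len(tokens)+2

def encode_target(word):
    # "<BOS> g i n g" as integer ids, mirroring k = [len(tokens)+1] in the diff
    return [BOS] + [tokens.index(ch) for ch in word]

print(encode_target("ging"))  # [32, 6, 8, 13, 6] with this 31-character vocabulary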

Files changed (1)
  1. PrateritumGPT.py +69 -31
PrateritumGPT.py CHANGED
@@ -5,8 +5,11 @@ from torch.utils.data import Dataset, DataLoader
 from torch.nn.utils.rnn import pad_sequence
 import math
 import progressbar
+import os
 
-device="cpu"
+Path=os.path.dirname(os.path.abspath(__file__))+"\\"
+
+device="cuda"
 
 def CreateBar():
     global bar
@@ -21,7 +24,7 @@ for i in range(len(tokens)):
     tokensdict.update({tokens[i]: [0] * i + [0] * (len(tokens) - (i + 1))})
 
 # Open the CSV file
-with open("C:\\Users\\marc2\\Downloads\\7eaaf0e22461b505c749e268c0b72bc4-12ebe211a929f039791dfeaa1a019b64cadddaf1\\7eaaf0e22461b505c749e268c0b72bc4-12ebe211a929f039791dfeaa1a019b64cadddaf1\\top-german-verbs.csv", 'r', encoding="utf-8") as file:
+with open(Path+"top-german-verbs.csv", 'r', encoding="utf-8") as file:
     # Create a CSV reader object
     reader = [i for i in csv.reader(file)][1:]
 
@@ -37,7 +40,6 @@ class CSVDataset(Dataset):
         sample = self.features[idx], self.labels[idx]
         return sample
 
-# Suppose you have your data as lists
 features = []
 labels = []
 padding=len(tokens)
@@ -48,7 +50,7 @@ for i in reader:
         k += [tokens.index(j)]
     #k += [-1] * (25 - len(k))
     features += [torch.Tensor(k)]
-    k = []
+    k = [len(tokens)+1]
     for j in i[8]:
         k += [tokens.index(j)]
     #k += [-1] * (25 - len(k))
@@ -109,12 +111,41 @@ def collate_fn(batch):
 
 train_loader = DataLoader(MyDataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
 
-model = TransformerModel(vocab_size=len(tokens)+1, emb_dim=16, nhead=4, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=256)
+#Embedding Dimension on epoch 10
+#32:10.49
+#64:6.55
+#128:6.44
+#256:9.63
+
+#Head Number on epoch 15
+#32:6.44
+#64:5.17
+#16:5.9402
+
+#Feed Forward Dimension on epoch 15+ (minimum)
+#128:5.17
+#256:3.49
+#512:3.44
+#1024:3.23
+
+#Num Encoder Layers on epochs 25 (minimum)
+#1:3.15
+#2:4.01
+
+#Num Decoder Layers on epochs 25 (minimum)
+#1:3.15
+#2:2.14
+#3:1.75
+#4:1.60
+
+#New model:
+#Dropout: 0
+#Forward Dim: 1024
+
+model = TransformerModel(vocab_size=len(tokens)+2, emb_dim=128, nhead=32, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=1024,dropout=0)
 loss_fn = nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
-epochs = 100
-
 try:
     model.load_state_dict(torch.load("data/PrateritumGPT.pth"))
     print("Successfully loaded model.")
@@ -122,35 +153,42 @@ except:
     pass
 
 #print(model(torch.zeros((1,25)).to(device),torch.zeros((1,25)).to(device)))
-inp=input("Which verb? ")
-src=[[]]
-tgt=[[tokens.index(inp[0])]]
-for i in inp:
-    src[0]+=[tokens.index(i)]
-str_=inp[0]
-for i in range(100):
-    out=model(torch.Tensor(src).to(device),torch.Tensor(tgt).to(device)).tolist()[0]
-    Best=0
-    Best_=tokens.index(" ")
-    for k,f in enumerate(out):
-        if f>Best:
-            Best=f
-            Best_=k
-    if Best_==len(tokens):
-        break
-    str_+=tokens[Best_]
-    tgt[0]+=[Best_]
-
-print(str_)
-
+def Prompt():
+    global tokens
+    global model
+    inp=input("Give me a verb: ")
+    src=[[]]
+    tgt=[[len(tokens)+1]]
+    for i in inp:
+        src[0]+=[tokens.index(i)]
+    str_=""
+    for i in range(100):
+        tgt_=torch.Tensor(tgt)
+        out=model(torch.Tensor(src).to(device),tgt_.to(device)).tolist()[0]
+        Best=0
+        Best_=tokens.index(" ")
+        for k,f in enumerate(out):
+            if f>Best:
+                Best=f
+                Best_=k
+        if Best_==len(tokens):
+            break
+        str_+=tokens[Best_]
+        tgt[0]+=[Best_]
+
+    print(str_)
+
+if eval(input('Train? ')):
+    epochs=eval(input("epochs "))
+else:
+    while True:
+        Prompt()
 
 for epoch in range(epochs):
     total_loss = 0.0
 
     CreateBar()
 
-    bar.start()
-
     for batch_idx, (inputs, targets) in enumerate(train_loader):
 
         #print("",inputs,targets)
@@ -168,7 +206,7 @@ for epoch in range(epochs):
 
         total_loss += loss.item()
 
-        mask = targets[:, i] != len(tokens)
+        mask = targets[:, i] != padding
         targets = targets[mask]
         inputs = inputs[mask]
 
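The new Prompt() routine is a greedy decoder: at each step it takes the highest-scoring output token, appends it to the target, and stops once the end/padding id wins. A sketch of the same loop written with torch.argmax, assuming model, tokens, and device exist as in the script (greedy_decode itself is illustrative, not part of the file):

import torch

def greedy_decode(model, tokens, src_text, device, max_len=100):
    # Ids mirror the script: len(tokens) is the stop/padding id,
    # len(tokens)+1 is the <BOS> id this commit introduces.
    eos_id, bos_id = len(tokens), len(tokens) + 1
    src = torch.Tensor([[tokens.index(ch) for ch in src_text]]).to(device)
    tgt = [bos_id]
    out = ""
    for _ in range(max_len):
        scores = model(src, torch.Tensor([tgt]).to(device))[0]
        next_id = int(torch.argmax(scores))  # replaces the manual Best/Best_ scan
        if next_id == eos_id:                # stop once the end id wins
            break
        out += tokens[next_id]
        tgt.append(next_id)
    return out

One edge case the manual scan in Prompt() carries: Best starts at 0, so if every score is negative the loop silently falls back to the space token; torch.argmax has no such fallback.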
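Similarly, the last hunk's mask = targets[:, i] != padding drops finished sequences from the batch at each step; the compared value itself is unchanged, since padding=len(tokens) earlier in the file, so the rename just ties the loop to the padding constant. A small worked example with invented shapes:

import torch

padding = 31                                # stands in for len(tokens)
targets = torch.tensor([[5, 9, padding],    # row 0 has finished by step 2
                        [7, 2, 11]])        # row 1 is still running
i = 2
mask = targets[:, i] != padding             # tensor([False, True])
targets = targets[mask]                     # tensor([[7, 2, 11]])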