GPT007 commited on
Commit
5b8c471
1 Parent(s): 3e3f29f

Update PrateritumGPT.py

Browse files
Files changed (1) hide show
  1. PrateritumGPT.py +4 -2
PrateritumGPT.py CHANGED
@@ -142,7 +142,7 @@ train_loader = DataLoader(MyDataset, batch_size=32, shuffle=True, collate_fn=col
142
  #Dropout: 0
143
  #Forward Dim: 1024
144
 
145
- model = TransformerModel(vocab_size=len(tokens)+2, emb_dim=128, nhead=32, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=1024,dropout=0)
146
  loss_fn = nn.CrossEntropyLoss()
147
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
148
 
@@ -166,8 +166,10 @@ def Prompt():
166
  tgt_=torch.Tensor(tgt)
167
  out=model(torch.Tensor(src).to(device),tgt_.to(device)).tolist()[0]
168
  Best=0
169
- Best_=tokens.index(" ")
170
  for k,f in enumerate(out):
 
 
171
  if f>Best:
172
  Best=f
173
  Best_=k
 
142
  #Dropout: 0
143
  #Forward Dim: 1024
144
 
145
+ model = TransformerModel(vocab_size=len(tokens)+2, emb_dim=128, nhead=32, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=512,dropout=0)
146
  loss_fn = nn.CrossEntropyLoss()
147
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
148
 
 
166
  tgt_=torch.Tensor(tgt)
167
  out=model(torch.Tensor(src).to(device),tgt_.to(device)).tolist()[0]
168
  Best=0
169
+ warn=tokens.index(" ")
170
  for k,f in enumerate(out):
171
+ if k==len(tokens):
172
+ f*=2
173
  if f>Best:
174
  Best=f
175
  Best_=k