MarkProMaster229 committed
Commit 333e7f3 · verified · 1 Parent(s): 70ba539

Update README.md

Files changed (1)
  1. README.md +14 -15
README.md CHANGED

````diff
@@ -1,9 +1,8 @@
 ```python
-checkpoint_path = "model_checkpoint.pt"
-
 if os.path.exists(checkpoint_path):
     checkpoint = torch.load(checkpoint_path)
     embedding_layer.load_state_dict(checkpoint['embedding_state'])
+    pos_encoding.load_state_dict(checkpoint['pos_encoding_state'])
     transformer_encoderLayer.load_state_dict(checkpoint['transformer_state'])
     output_layer.load_state_dict(checkpoint['output_state'])
     optimizer.load_state_dict(checkpoint['optimizer_state'])
@@ -13,12 +12,12 @@ else:
     start_epoch = 0
     print(" Checkpoint not found, starting training from scratch")
 
-
-epochNum = 20
+epochNum = 10
 for epoch in range(epochNum):
     optimizer.zero_grad()
     epochmy = start_epoch + epoch
     embedded = embedding_layer(input_ids)
+    embedded = pos_encoding(embedded)
     src = embedded.transpose(0, 1)
 
     outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
@@ -26,27 +25,27 @@ for epoch in range(epochNum):
 
     logits = output_layer(outputTransformer)
     loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
+    before = pos_encoding.pos_embedding.weight.clone()
     loss.backward()
-    optimizer.step()
+    optimizer.step()  # update the weights
+    after = pos_encoding.pos_embedding.weight
+    print(f"pos_encoding weight change: {(after - before).abs().sum():.6f}")
+    print("Loss:", loss.item())
 
+# After training (or inside the loop, to watch how it evolves)
 with torch.no_grad():
     embedded = embedding_layer(input_ids)
+    embedded = pos_encoding(embedded)
     src = embedded.transpose(0, 1)
     outputTransformer = transformer_encoderLayer(src, src_key_padding_mask=(attention_mask == 0))
     outputTransformer = outputTransformer.transpose(0, 1)
     logits = output_layer(outputTransformer)  # [batch, seq_len, vocab_size]
 
-
+    # take the most likely token at each position
     predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch, seq_len]
 
-    predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
+    # convert the token ids back to text
+    predicted_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=False)
     print("Predicted text:", predicted_text[0])
 
-print(f"Epoch [{epoch + 1}/{epochNum}] — Loss: {loss.item():.4f}")
-torch.save({
-    'embedding_state': embedding_layer.state_dict(),
-    'transformer_state': transformer_encoderLayer.state_dict(),
-    'output_state': output_layer.state_dict(),
-    'optimizer_state': optimizer.state_dict(),
-    'epoch': epochmy
-}, "model_checkpoint.pt")
+print("Loss before backward:", loss.item())
````