#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 17 16:20:22 2023

@author: fujidai
"""


from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses
import torch
from torch.utils.data import DataLoader
from sentence_transformers.datasets import ParallelSentencesDataset

import logging

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


teacher_model_name = '/Users/fujidai/sinTED/paraphrase-mpnet-base-v2'   # Monolingual teacher model whose embedding space we want to extend to other languages
#teacher_model_name = '/Users/fujidai/TED2020_data/tisikizyouryu/bert-large-nli-mean-tokens'   # Alternative monolingual teacher model

student_model_name = '/Users/fujidai/dataseigen/09-MarginMSELoss-finetuning-7-5'       # Multilingual student model that learns to imitate the teacher

max_seq_length = 128                # Maximum input length for the student model (number of word pieces)
train_batch_size = 64               # Batch size for training
inference_batch_size = 64           # Batch size at inference
max_sentences_per_language = 500000 # Maximum number of parallel sentences per language for training
train_max_sentence_length = 250     # Maximum length (characters) of parallel training sentences

num_epochs = 3                       # Train for this many epochs
num_warmup_steps = 10000             # Warm-up steps

num_evaluation_steps = 1000          # Evaluate performance every x steps
dev_sentences = 1000                 # Number of parallel sentences to be used for development


######## Start the extension of the teacher model to multiple languages ########
logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name, device='mps')


logger.info("Create student model from scratch")

word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
# Apply mean pooling to get one fixed-size sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())  # (a Dense module could be appended here to change the output dimensionality, e.g. to 768)
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='mps')
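
# The note above mentions projecting to 768 dimensions with a Dense layer. A minimal sketch of
# that option (the 768 target size is an assumption taken from the comment); left commented out
# because this run keeps the plain Transformer + mean-pooling student:
#dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),
#                           out_features=768,
#                           activation_function=torch.nn.Tanh())
#student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model], device='mps')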

print(teacher_model)
print(student_model)


train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
train_data.load_data('/Users/fujidai/dataseigen/09-04_09-04.txt')  # tab-separated Japanese-English sentence pairs
#train_data.load_data('/Users/fujidai/TED2020_data/wmt21/output-100.txt')  # tab-separated Japanese-English sentence pairs
#train_data.load_data('/Users/fujidai/TED2020_data/data/tuikazumi/en-ja/TED2020.en-ja.en')
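# max_sentences_per_language and train_max_sentence_length are defined above but not passed to
# load_data(), so they currently have no effect. A sketch of how they could be applied, assuming
# the max_sentences / max_sentence_length keyword arguments of ParallelSentencesDataset.load_data:
#train_data.load_data('/Users/fujidai/dataseigen/09-04_09-04.txt',
#                     max_sentences=max_sentences_per_language,
#                     max_sentence_length=train_max_sentence_length)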
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)

print(train_data)
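
# num_evaluation_steps and dev_sentences are set above, but fit() below receives no evaluator, so
# no evaluation actually runs during training. A sketch of wiring in an MSE evaluator, assuming a
# tab-separated held-out file (dev.tsv is a hypothetical path):
#dev_src, dev_trg = [], []
#with open('dev.tsv', encoding='utf8') as fIn:
#    for line in fIn:
#        if len(dev_src) >= dev_sentences:
#            break
#        src, trg = line.strip().split('\t')
#        dev_src.append(src)
#        dev_trg.append(trg)
#dev_evaluator = evaluation.MSEEvaluator(dev_src, dev_trg, teacher_model=teacher_model,
#                                        batch_size=inference_batch_size)
# Pass evaluator=dev_evaluator to student_model.fit() to activate it.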



# Train the student model so that its sentence embeddings mimic the teacher's embeddings
logger.info("Start training")
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  epochs=num_epochs,
                  warmup_steps=num_warmup_steps,
                  evaluation_steps=num_evaluation_steps,
                  #output_path='best_paraphrase-mpnet-base-v2__xlm-roberta-base_epoch-3',
                  #save_best_model=True,
                  optimizer_params={'lr': 2e-5, 'eps': 1e-6},
                  checkpoint_path='paraphrase-mpnet-base-v2_09-MarginMSELoss-finetuning-7-5_2',
                  checkpoint_save_steps=820
                  )

student_model.save('paraphrase-mpnet-base-v2_09-MarginMSELoss-finetuning-7-5')
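
# Minimal usage sketch for the saved student model (the sentence pair below is illustrative only):
#trained_model = SentenceTransformer('paraphrase-mpnet-base-v2_09-MarginMSELoss-finetuning-7-5', device='mps')
#embeddings = trained_model.encode(['This is a test sentence.', 'これはテスト文です。'])
#print(embeddings.shape)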










