|
|
|
|
|
""" |
|
Created on Sat Jun 17 16:20:22 2023 |
|
|
|
@author: fujidai |
|
""" |
|
|
|
|
|
from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from sentence_transformers.datasets import ParallelSentencesDataset |
|
from datetime import datetime |
|
|
|
import os |
|
import logging |
|
import sentence_transformers.util |
|
import csv |
|
import gzip |
|
from tqdm.autonotebook import tqdm |
|
import numpy as np |
|
import zipfile |
|
import io |
|
|
|
# Route log records through sentence-transformers' LoggingHandler so that
# library progress output and our own messages share one timestamped format.
LOG_FORMAT = '%(asctime)s - %(message)s'
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Checkpoints produced by WMT_da_finetuning.py.  The literal values below are
# Japanese placeholder text (roughly: "the 'teacher' model created by
# WMT_da_finetuning.py" / "the 'student' model ..."); replace them with real
# model paths/names before running.
teacher_model_name = 'WMT_da_finetuning.pyで作成した「教師」モデル'

student_model_name = 'WMT_da_finetuning.pyで作成した「生徒」モデル'

# Truncation length (tokens) for the student transformer's inputs.
max_seq_length = 128
# Mini-batch size for distillation training.
train_batch_size = 64
# Batch size intended for teacher-embedding inference.
# NOTE(review): not passed anywhere in this script as written — confirm
# whether it should be given to ParallelSentencesDataset.
inference_batch_size = 64
# Per-language sentence cap and maximum sentence length for training data.
# NOTE(review): neither is passed to load_data() in this script as written,
# so the dataset defaults apply — confirm whether they should be wired in.
max_sentences_per_language = 500000
train_max_sentence_length = 250

num_epochs = 100
# Linear LR warmup steps for fit().
num_warmup_steps = 10000

# Steps between evaluations during fit().
num_evaluation_steps = 1000
# Size of a held-out dev set; unused in this script (presumably intended
# for building a dev evaluator — verify against the original example).
dev_sentences = 1000
|
|
|
|
|
|
|
# Pick the best available device.  The original hard-coded device='mps',
# which raises on any machine without Apple's Metal backend; fall back to
# CUDA, then CPU.  On Apple hardware the behavior is unchanged.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

logger.info("Load teacher model")

# Teacher: the fixed model whose sentence embeddings the student will be
# trained to reproduce.
teacher_model = SentenceTransformer(teacher_model_name, device=device)

logger.info("Create student model from scratch")

# Student: a bare transformer encoder plus a pooling head (models.Pooling
# with default settings — presumably mean pooling; confirm if it matters),
# assembled into a SentenceTransformer.
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

print(teacher_model)
print(student_model)
|
|
|
|
|
# NOTE: ParallelSentencesDataset is already imported at the top of the file;
# the redundant re-import that used to sit here has been removed.

# Pairs each training sentence with the teacher's embedding of its English
# counterpart; the student is optimized to reproduce those embeddings.
# batch_size / max_sentences / max_sentence_length wire in the constants
# defined above (they were previously unused, leaving the dataset on its
# defaults), matching the upstream sentence-transformers distillation example.
train_data = ParallelSentencesDataset(student_model=student_model,
                                      teacher_model=teacher_model,
                                      batch_size=inference_batch_size,
                                      use_embedding_cache=True)
# Tab-separated parallel corpus: English sentence <TAB> other-language sentence.
train_data.load_data('/WMT_da_学習データ_88993文/tab_en-other.txt',
                     max_sentences=max_sentences_per_language,
                     max_sentence_length=train_max_sentence_length)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# Distillation objective: MSE between student and teacher sentence embeddings.
train_loss = losses.MSELoss(model=student_model)

print(train_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Train the student to mimic the teacher's embeddings, checkpointing
# periodically.  NOTE(review): evaluation_steps is supplied without an
# evaluator — confirm whether an evaluator was intended here.
fit_kwargs = dict(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=num_warmup_steps,
    evaluation_steps=num_evaluation_steps,
    optimizer_params={'lr': 2e-5, 'eps': 1e-6},
    checkpoint_path='checkpoint_savename',
    checkpoint_save_steps=1947,
)
student_model.fit(**fit_kwargs)

# Persist the distilled student model.
student_model.save('savename')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|