# Source: WMT_da-data_finetuning / distillation.py
# (by F-Haru; last commit f617fea: "Update distillation.py")
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 17 16:20:22 2023
@author: fujidai
"""
from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses
import torch
from torch.utils.data import DataLoader
from sentence_transformers.datasets import ParallelSentencesDataset
from datetime import datetime
import os
import logging
import sentence_transformers.util
import csv
import gzip
from tqdm.autonotebook import tqdm
import numpy as np
import zipfile
import io
# Configure root logging: INFO and above, each record prefixed with a
# timestamp, emitted through sentence-transformers' LoggingHandler.
_log_format = '%(asctime)s - %(message)s'
_log_datefmt = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(
    level=logging.INFO,
    format=_log_format,
    datefmt=_log_datefmt,
    handlers=[LoggingHandler()],
)
# Module-level logger used for the progress messages below.
logger = logging.getLogger(__name__)
# ---- Models ------------------------------------------------------------------
# Teacher: the "teacher" model produced by WMT_da_finetuning.py. Its sentence
# embeddings are the targets the student is trained to reproduce.
teacher_model_name = 'WMT_da_finetuning.pyで作成した「教師」モデル'
# Student: the "student" base model produced by WMT_da_finetuning.py, which is
# trained below to imitate the teacher.
student_model_name = 'WMT_da_finetuning.pyで作成した「生徒」モデル'

# ---- Sequence length / batching ----------------------------------------------
max_seq_length = 128          # student max input length (number of word pieces)
train_batch_size = 64         # batch size for training
inference_batch_size = 64     # batch size at inference (encoding)

# ---- Data limits -------------------------------------------------------------
max_sentences_per_language = 500_000  # max number of parallel sentences for training
train_max_sentence_length = 250       # max length (characters) of a training sentence
dev_sentences = 1_000                 # parallel sentences to use for development

# ---- Schedule ----------------------------------------------------------------
num_epochs = 100              # number of training epochs
num_warmup_steps = 10_000     # learning-rate warmup steps
num_evaluation_steps = 1_000  # evaluate performance every N training steps
######## Build the teacher and student models ########
# Choose the compute device at runtime instead of hard-coding Apple's 'mps'
# backend, so the script also runs on CUDA GPUs or CPU-only machines.
# 'mps' is still preferred when available, matching the original behaviour on Macs.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
logger.info("Using device: %s", device)

logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name, device=device)

logger.info("Create student model from scratch")
# Transformer encoder for the student; inputs are truncated to max_seq_length word pieces.
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
# Apply mean pooling to get one fixed-size sentence vector.
# NOTE: the MSE loss compares student and teacher embeddings, so their sizes must
# agree; to change the dimensionality (e.g. to 768), add a models.Dense layer
# after pooling.
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

print(teacher_model)
print(student_model)
# Build the distillation dataset. Per the sentence-transformers docs,
# ParallelSentencesDataset encodes the first-column (source) sentence with the
# teacher and uses that embedding as the regression target for every sentence
# on the same line — confirm against the installed library version.
# batch_size controls how many sentences the teacher encodes at once.
train_data = ParallelSentencesDataset(
    student_model=student_model,
    teacher_model=teacher_model,
    batch_size=inference_batch_size,
)
# Each line of the file is: English sentence <TAB> sentence in the other language.
# Pass the configured limits explicitly. Otherwise load_data falls back to its own
# defaults (e.g. a 128-character sentence cap instead of train_max_sentence_length).
train_data.load_data(
    '/WMT_da_学習データ_88993文/tab_en-other.txt',
    max_sentences=max_sentences_per_language,
    max_sentence_length=train_max_sentence_length,
)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# Mean-squared error between student embeddings and the teacher's target embeddings.
train_loss = losses.MSELoss(model=student_model)
print(train_data)
# Train the student to reproduce the teacher's embeddings with train_loss.
# NOTE(review): no evaluator is passed to fit(), so evaluation_steps likely has
# no effect here — confirm against the installed sentence-transformers version.

# Steps between checkpoint saves; adjust to the dataset size / batch size in use.
checkpoint_every = 1947

fit_arguments = dict(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=num_warmup_steps,
    evaluation_steps=num_evaluation_steps,
    optimizer_params={'lr': 2e-5, 'eps': 1e-6},
    checkpoint_path='checkpoint_savename',
    checkpoint_save_steps=checkpoint_every,
)
student_model.fit(**fit_arguments)

# Save the final student model ('savename' looks like a placeholder output
# directory — presumably to be replaced per run).
student_model.save('savename')