#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Aug 13 20:57:28 2023 @author: fujidai """ import torch from sentence_transformers import SentenceTransformer, InputExample, losses,models from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses from sentence_transformers.readers import InputExample from torch.utils.data import DataLoader from transformers import AutoTokenizer from sentence_transformers.SentenceTransformer import SentenceTransformer import torch import torch.nn.functional as F import numpy as np from sentence_transformers import SentenceTransformer, util #word_embedding_model = models.Transformer('/paraphrase-multilingual-mpnet-base-v2', max_seq_length=512)# modelの指定をする word_embedding_model = models.Transformer('/paraphrase-mpnet-base-v2', max_seq_length=512)# modelの指定をする pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) #dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),out_features=16) model = SentenceTransformer(modules=[word_embedding_model, pooling_model],device='mps') print(model) with open('/WMT_da_学習データ_88993文/en-label-正規化.txt', 'r') as f:#en-labelを正規化したもの raberu = f.read() raberu_lines = raberu.splitlines()#改行コードごとにリストに入れている data = [] for i in range(len(raberu_lines)): data.append(float(raberu_lines[i]))#Negative en-ja cos_simをdataに入れている with open('/WMT_da_学習データ_88993文/en-origin.txt', 'r') as f:#daの英語の文が入っているデータ(翻訳を一度もしていない) left = f.read() left_lines = left.splitlines() with open('/WMT_da_学習データ_88993文/en-pseudo.txt', 'r') as f:#daの英語じゃない方の文をgoogle翻訳に入れたものが入っているデータ senter = f.read() senter_lines = senter.splitlines() with open('/WMT_da_学習データ_88993文/en-pseudo-pseudo.txt', 'r') as f:#daの英語の文をgoogle翻訳に入れたものをもう一度google翻訳に入れたものが入っているデータ right = f.read() right_lines = right.splitlines()#改行コードごとにリストに入れている train_examples = [] for i in range(len(left_lines)): pair=[] pair.append(left_lines[i])#left_lines側のi行目をtextsに追加している pair.append(senter_lines[i]) pair.append(right_lines[i])#right_lines側のi行目をtextsに追加している example = InputExample(texts=pair, label=data[i])#textsをラベル付きで追加している #print(example)# #label=1-data[i]の1は positive cos_sim #if aq>=0.25: train_examples.append(example)#学習として入れるものに入れている print(len(train_examples)) device = torch.device('mps') #print(device) import torch.nn.functional as F train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8) #train_loss = losses.MarginMSELoss(model=model,similarity_fct=F.cosine_similarity) train_loss = losses.CosineSimilarityLoss(model) #Tune the model model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=100, warmup_steps=100,show_progress_bar=True, #output_path='完成2best-6-30', checkpoint_path='checkpoint_save_name',checkpoint_save_steps=11125,#どのくらいのイテレーションごとに保存するか save_best_model=True,#,#checkpoint_save_total_limit=5 #optimizer_params= {'lr': 1e-05}# ) model.save("last_save_name") #