|
|
|
|
|
""" |
|
Created on Sun Aug 13 20:57:28 2023 |
|
|
|
@author: fujidai |
|
""" |
|
|
|
|
|
import torch |
|
from sentence_transformers import SentenceTransformer, InputExample, losses,models |
|
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses |
|
from sentence_transformers.readers import InputExample |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoTokenizer |
|
from sentence_transformers.SentenceTransformer import SentenceTransformer |
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer, util |
|
|
|
|
|
|
|
# Build the sentence encoder from two stages: a Transformer module loaded from
# a local checkpoint (max_seq_length=512) and a Pooling module sized to that
# Transformer's word-embedding dimension.
word_embedding_model = models.Transformer('/paraphrase-mpnet-base-v2', max_seq_length=512)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Prefer Apple's Metal (MPS) backend, but fall back to the CPU so the script
# still starts on machines where MPS is unavailable instead of failing while
# the model is being moved to the device.
_model_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=_model_device)

# Show the assembled module stack.
print(model)
|
|
|
|
|
# Load the training labels: one float per line. Labels are later matched to
# the sentence files by line index, so the order must not be changed here.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale.
with open('/WMT_da_学習データ_88993文/en-label-正規化.txt', 'r', encoding='utf-8') as f:
    raberu = f.read()

raberu_lines = raberu.splitlines()

# float() raises ValueError on a blank or malformed line. That is kept on
# purpose: silently skipping a line would shift every following label onto
# the wrong sentence.
data = [float(line) for line in raberu_lines]
|
|
|
|
|
|
|
|
|
# Read the three parallel sentence files. Line i of each file belongs to the
# same training example, so each file is split into lines without filtering.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale (assumes the files are UTF-8 - TODO confirm).
with open('/WMT_da_学習データ_88993文/en-origin.txt', 'r', encoding='utf-8') as f:
    left = f.read()
left_lines = left.splitlines()

with open('/WMT_da_学習データ_88993文/en-pseudo.txt', 'r', encoding='utf-8') as f:
    senter = f.read()
senter_lines = senter.splitlines()

with open('/WMT_da_学習データ_88993文/en-pseudo-pseudo.txt', 'r', encoding='utf-8') as f:
    right = f.read()
right_lines = right.splitlines()
|
|
|
|
|
# Build one InputExample per line index from the three sentence lists and the
# label list.
#
# The four inputs must have the same length. Check this up front: otherwise a
# shorter list fails with an IndexError partway through, and a longer list
# would have its extra entries dropped without any warning.
if not (len(left_lines) == len(senter_lines) == len(right_lines) == len(data)):
    raise ValueError(
        'Training inputs are misaligned: '
        f'{len(left_lines)} origin, {len(senter_lines)} pseudo and '
        f'{len(right_lines)} pseudo-pseudo sentences, {len(data)} labels'
    )

# Each example carries the three texts in the order origin, pseudo,
# pseudo-pseudo, together with the label found on the same line.
train_examples = [
    InputExample(texts=[origin, pseudo, pseudo_pseudo], label=score)
    for origin, pseudo, pseudo_pseudo, score in zip(left_lines, senter_lines, right_lines, data)
]

# Report how many examples were assembled.
print(len(train_examples))
|
|
|
|
|
# Device handle for Apple's Metal (MPS) backend. Nothing in the visible part
# of this script reads it; the model receives its device at construction.
device = torch.device('mps')

import torch.nn.functional as F

# Serve the training examples in shuffled mini-batches of 8.
train_dataloader = DataLoader(train_examples, batch_size=8, shuffle=True)

# NOTE(review): every InputExample built earlier in this script holds three
# texts, while CosineSimilarityLoss appears to compare only the first two
# sentence embeddings - confirm that the third text is meant to be unused.
train_loss = losses.CosineSimilarityLoss(model)

# Train for 100 epochs with 100 warm-up steps, writing a checkpoint every
# 11125 steps (presumably once per epoch at batch size 8 - TODO confirm).
# NOTE(review): no evaluator or output_path is passed, so save_best_model
# presumably has no effect here - confirm against the library documentation.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=100,
    warmup_steps=100,
    show_progress_bar=True,
    checkpoint_path='checkpoint_save_name',
    checkpoint_save_steps=11125,
    save_best_model=True,
)

# Persist the final model state after the last epoch.
model.save("last_save_name")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|