Upload 7 files
Browse files- .gitattributes +1 -0
- en-ja-100000-karanasi.txt +0 -0
- output-100000-karanasi.txt +3 -0
- pseudo-english-sentence-100000-karanasi.txt +0 -0
- pseudo-english_english_100000_cos-sim-karanasi.txt +0 -0
- pseudo-pseudo-english-sentence-100000-karanasi.txt +0 -0
- pseudo-pseudo-english_english_100000_cos-sim-karanasi.txt +0 -0
- teacher_finetune.py +95 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
完成2-MarginMSELoss-finetuning-6-30/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
完成2-MarginMSELoss-finetuning-6-30/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
output-100000-karanasi.txt filter=lfs diff=lfs merge=lfs -text
|
en-ja-100000-karanasi.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
output-100000-karanasi.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:213ebc63bb4cfd4b740097909174855d3d64cfba85977d4620768d615d22b27b
|
3 |
+
size 20491048
|
pseudo-english-sentence-100000-karanasi.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pseudo-english_english_100000_cos-sim-karanasi.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pseudo-pseudo-english-sentence-100000-karanasi.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pseudo-pseudo-english_english_100000_cos-sim-karanasi.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
teacher_finetune.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Fine-tune a SentenceTransformer teacher model with MarginMSELoss.

Builds (english, round-trip-english, pseudo-english) sentence triples from
TED data and uses |cos_sim_1 - cos_sim_2| — the absolute gap between the two
precomputed cosine similarities — as the gold margin label for MarginMSELoss.

Created on Fri Jun 30 08:47:31 2023

@author: fujidai
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses, models


def _read_lines(path):
    """Return the file's contents as a list of lines (newlines stripped)."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().splitlines()


def _read_floats(path):
    """Return the file's contents as a list of floats, one value per line."""
    return [float(line) for line in _read_lines(path)]


def main():
    """Build the model, assemble the training triples, and run fine-tuning."""
    # Transformer encoder + mean pooling.
    # NOTE(review): device='mps' is hard-coded — assumes an Apple-silicon host.
    word_embedding_model = models.Transformer('paraphrase-mpnet-base-v2', max_seq_length=512)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='mps')
    print(model)

    # cos_sim(pseudo-pseudo-english, english): en->ja->en round-trip similarity.
    data = _read_floats('pseudo-pseudo-english_english_100000_cos-sim-karanasi.txt')
    # cos_sim(pseudo-english, english): ja->en translation similarity.
    data2 = _read_floats('pseudo-english_english_100000_cos-sim-karanasi.txt')

    # TED English sentences.
    left_lines = _read_lines('en-ja-100000-karanasi.txt')
    # English -> (Google Translate) Japanese -> (Google Translate) English.
    senter_lines = _read_lines('pseudo-pseudo-english-sentence-100000-karanasi.txt')
    # TED Japanese -> (Google Translate) English.
    right_lines = _read_lines('pseudo-english-sentence-100000-karanasi.txt')

    # Guard against silently truncated training data: all five files must align
    # (the original index loop would have raised IndexError or dropped rows).
    n = len(left_lines)
    if not (len(senter_lines) == n and len(right_lines) == n
            and len(data) >= n and len(data2) >= n):
        raise ValueError('input files do not have matching line counts')

    # One InputExample per sentence triple; the label is the absolute gap
    # between the two cosine similarities (the gold margin for MarginMSELoss).
    train_examples = [
        InputExample(texts=[left, senter, right], label=abs(d1 - d2))
        for left, senter, right, d1, d2
        in zip(left_lines, senter_lines, right_lines, data, data2)
    ]

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)
    train_loss = losses.MarginMSELoss(model=model, similarity_fct=F.cosine_similarity)

    # Fine-tune; save a checkpoint every 6195 training steps.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=100,
        warmup_steps=100,
        show_progress_bar=True,
        checkpoint_path='paraphrase-mpnet-base-v2_finetuning-2',
        checkpoint_save_steps=6195,
        save_best_model=True,  # checkpoint_save_total_limit could bound disk use
    )
    model.save("paraphrase-mpnet-base-v2_finetuning")


if __name__ == "__main__":
    main()
|