## Train

H/W : Colab A100 40GB
Data : jaeyong2/Ko-emb-PreView
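Each example in the dataset pairs a long passage (`context`) with its real `Title` (positive) and a `Fake Title` (negative). The model is fine-tuned so that, in embedding space, the context sits closer to the real title than to the fake one, via a margin-based triplet loss: `max(‖a − p‖₂ − ‖a − n‖₂ + margin, 0)`.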
```
import datasets
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from sentence_transformers.util import batch_to_device

# TripletLoss is assumed to be the standard margin-based triplet loss;
# torch.nn.TripletMarginLoss exposes the same (anchor, positive, negative) interface.
TripletLoss = torch.nn.TripletMarginLoss

model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Ko-emb-PreView")
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# gte-multilingual-base uses a custom architecture, so trust_remote_code is required
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(torch.bfloat16)
triplet_loss = TripletLoss(margin=1.0)

optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(3):  # epoch loop
    model.train()
    total_loss = 0
    count = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = None
        for index in range(len(batch["context"])):
            # Tokenize anchor (context), positive (Title), and negative (Fake Title)
            anchor_encodings = tokenizer([batch["context"][index]], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
            positive_encodings = tokenizer([batch["Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
            negative_encodings = tokenizer([batch["Fake Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            # Model output (embedding vectors): take the [CLS] token vector
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]

            # Compute the triplet loss and accumulate it over the batch
            if loss is None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)

        loss /= len(batch["context"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        count += 1
```
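After training, the embeddings can be used for retrieval-style scoring. Below is a minimal sketch (not part of the original script) that reuses `model`, `tokenizer`, `device`, and `batch_to_device` from the loop above; the `embed` helper and the sample strings are illustrative.

```
import torch
import torch.nn.functional as F

model.eval()

def embed(texts):
    # Same pooling as in training: the [CLS] vector of the last hidden state
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=4096, return_tensors="pt")
    encodings = batch_to_device(encodings, device)
    with torch.no_grad():
        vectors = model(**encodings)[0][:, 0, :]
    return F.normalize(vectors.float(), p=2, dim=1)

query = embed(["Some Korean news passage"])                         # anchor-style input
candidates = embed(["Its real headline", "An unrelated headline"])  # title-style inputs
print(query @ candidates.T)  # cosine similarities after L2 normalization
```

After fine-tuning, the true title should receive the higher similarity score.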
## Evaluation