---
library_name: transformers
license: apache-2.0
datasets:
- jaeyong2/Thai-emb-PreView
language:
- th
base_model:
- Alibaba-NLP/gte-multilingual-base
---

# Model Card for jaeyong2/gte-multilingual-base-Thai-embedding

This model is Alibaba-NLP/gte-multilingual-base fine-tuned on the jaeyong2/Thai-emb-PreView dataset for Thai text embedding.

## Model Details
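The fine-tuned weights load through the Hugging Face `transformers` library like the base model. A minimal usage sketch, assuming `trust_remote_code=True` is needed for the gte architecture and using the same [CLS]-token pooling as the training code below:

```
import torch
from transformers import AutoTokenizer, AutoModel

model_id = "jaeyong2/gte-multilingual-base-Thai-embedding"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

texts = ["ตัวอย่างประโยคภาษาไทย", "อีกหนึ่งประโยคตัวอย่าง"]
encodings = tokenizer(texts, truncation=True, padding=True, return_tensors="pt")
with torch.no_grad():
    # [CLS]-token embeddings
    embeddings = model(**encodings)[0][:, 0, :]
print(embeddings.shape)
```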


## Train

- H/W : Colab A100 40GB
- Data : jaeyong2/Thai-emb-PreView

```
import datasets
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Thai-emb-PreView")
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# trust_remote_code is needed because the gte architecture ships custom modeling code
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(torch.bfloat16)
# TripletLoss and batch_to_device are small helpers (a sketch is given below this block)
triplet_loss = TripletLoss(margin=1.0)

optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(3):  # number of training epochs
    model.train()
    total_loss = 0
    count = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = None
        for index in range(len(batch["context"])):
            anchor_encodings = tokenizer([batch["context"][index]], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
            positive_encodings = tokenizer([batch["Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
            negative_encodings = tokenizer([batch["Fake Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            # Model forward pass: use the [CLS] token vector as the embedding
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]

            # Accumulate the triplet loss over the batch
            if loss is None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)
        loss /= len(batch["context"])
        loss.backward()
        optimizer.step()
```
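The training snippet relies on two helpers that are not shown in this card: `TripletLoss` and `batch_to_device`. A minimal sketch of what they could look like, assuming a standard margin-based triplet loss on Euclidean distances and a plain dict-to-device move (the exact implementations used for training may differ):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Margin-based triplet loss: pull anchors toward positives, push them from negatives."""
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        return F.relu(pos_dist - neg_dist + self.margin).mean()


def batch_to_device(encodings, device):
    """Move every tensor in a tokenizer output dict to the target device."""
    return {key: value.to(device) for key, value in encodings.items()}
```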

## Evaluation

Code:
```
import datasets
import torch
import numpy as np
from sklearn.metrics import pairwise_distances
from tqdm import tqdm


dataset = datasets.load_dataset("jaeyong2/Thai-emb-PreView")
validation_dataset = dataset["test"].select(range(1000))

model.eval()

def evaluate(validation_dataset):
    correct_count = 0

    for item in tqdm(validation_dataset):
        # get_embedding returns the [CLS] embedding of a text (a sketch is given below this block)
        query_embedding = get_embedding(item["context"], model, tokenizer)
        document_embedding = get_embedding(item["Title"], model, tokenizer)
        negative_embedding = get_embedding(item["Fake Title"], model, tokenizer)

        # Cosine distance between the query and the true / fake title
        positive_distances = pairwise_distances(query_embedding.detach().cpu().float().numpy(), document_embedding.detach().cpu().float().numpy(), metric="cosine")
        negative_distances = pairwise_distances(query_embedding.detach().cpu().float().numpy(), negative_embedding.detach().cpu().float().numpy(), metric="cosine")

        # Correct if the true title is closer to the query than the fake title
        if positive_distances < negative_distances:
            correct_count += 1

    accuracy = correct_count / len(validation_dataset)
    return accuracy

results = evaluate(validation_dataset)
print(f"Validation Results: {results}")
```
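`get_embedding` is not defined in this card. A minimal sketch consistent with the [CLS]-token pooling used in the training loop above (the body is an assumption, not the exact evaluation code):

```
import torch

def get_embedding(text, model, tokenizer, max_length=4096):
    # Tokenize and move the inputs to the model's device
    encodings = tokenizer([text], truncation=True, padding=True, max_length=max_length, return_tensors="pt")
    encodings = {key: value.to(model.device) for key, value in encodings.items()}
    with torch.no_grad():
        # [CLS] token vector of the last hidden state
        embedding = model(**encodings)[0][:, 0, :]
    return embedding
```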

Accuracy
- Alibaba-NLP/gte-multilingual-base : 0.953
- jaeyong2/gte-multilingual-base-Thai-embedding : 0.991


### License
- Alibaba-NLP/gte-multilingual-base : https://choosealicense.com/licenses/apache-2.0/