---
license: mit
language:
- ko
pipeline_tag: feature-extraction
---

# KoSimCSE Training on Amazon SageMaker

## Usage

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from transformers import AutoModel, AutoTokenizer, logging
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

class SimCSEConfig(PretrainedConfig):
    def __init__(self, version=1.0, **kwargs):
        self.version = version
        super().__init__(**kwargs)

class SimCSEModel(PreTrainedModel):
    config_class = SimCSEConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_pretrained(config.base_model)
        self.hidden_size: int = self.backbone.config.hidden_size
        # Projection head (dense + tanh), only applied during training
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation = nn.Tanh()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor = None,
        # RoBERTa variants don't have token_type_ids, so this argument is optional
        token_type_ids: Tensor = None,
    ) -> Tensor:
        # shape of input_ids: (batch_size, seq_len)
        # shape of attention_mask: (batch_size, seq_len)
        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Use the [CLS] token embedding as the sentence representation
        emb = outputs.last_hidden_state[:, 0]

        if self.training:
            emb = self.dense(emb)
            emb = self.activation(emb)

        return emb

def show_embedding_score(tokenizer, model, sentences):
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    embeddings = model(**inputs)
    score01 = cal_score(embeddings[0, :], embeddings[1, :])  # similar sentence pair
    score02 = cal_score(embeddings[0, :], embeddings[2, :])  # dissimilar sentence pair
    print(score01, score02)

def cal_score(a, b):
    if len(a.shape) == 1: a = a.unsqueeze(0)
    if len(b.shape) == 1: b = b.unsqueeze(0)
    a_norm = a / a.norm(dim=1)[:, None]
    b_norm = b / b.norm(dim=1)[:, None]
    return torch.mm(a_norm, b_norm.transpose(0, 1)) * 100

# Load pre-trained model
model = SimCSEModel.from_pretrained("daekeun-ml/KoSimCSE-unsupervised-roberta-large")
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/KoSimCSE-unsupervised-roberta-large")

# Inference example
sentences = ['이번 주 일요일에 분당 이마트 점은 문을 여나요?',
             '일요일에 분당 이마트는 문 열어요?',
             '분당 이마트 점은 토요일에 몇 시까지 하나요']

show_embedding_score(tokenizer, model.cpu(), sentences)
```
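
The `cal_score` helper returns cosine similarity scaled by 100, so the score between the first two (semantically similar) sentences should be clearly higher than the score between the first and third.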

## Introduction

[SimCSE](https://aclanthology.org/2021.emnlp-main.552/) is a simple yet highly effective embedding technique based on contrastive learning. Unsupervised training can be performed without preparing any ground-truth labels, and high-performance supervised training is possible when a good NLI (Natural Language Inference) dataset is available. The concept is very simple and the pseudo-code is intuitive, so the implementation is not difficult, but we have seen many people still struggle to train this model.
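
For reference, the core of the unsupervised objective is to encode the same sentence twice (two different dropout masks produce two slightly different embeddings) and pull those two views together while pushing the other sentences in the batch apart. The snippet below is a minimal, self-contained sketch of that InfoNCE-style loss, not the exact training code of this repository; the `temperature` value of 0.05 matches the hyperparameters listed under Performance.

```python
import torch
import torch.nn.functional as F

def simcse_unsup_loss(emb1: torch.Tensor, emb2: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """Minimal unsupervised SimCSE (InfoNCE) loss.

    emb1, emb2: (batch_size, hidden_size) embeddings of the *same* sentences,
    obtained from two forward passes so that the dropout masks differ.
    """
    # Cosine similarity between every pair in the batch: (batch_size, batch_size)
    sim = F.cosine_similarity(emb1.unsqueeze(1), emb2.unsqueeze(0), dim=-1) / temperature
    # The positive for sentence i is its own second view, i.e. the diagonal entry
    labels = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, labels)

# Inside a training step (model as defined in the Usage section above):
# emb1 = model(**batch)   # first forward pass (dropout mask A)
# emb2 = model(**batch)   # second forward pass (dropout mask B)
# loss = simcse_unsup_loss(emb1, emb2)
```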

The official implementation from the paper's authors is publicly available, but it is not well suited to a step-by-step walkthrough. Therefore, we reorganized the code based on [Simple-SimCSE's GitHub](https://github.com/hppRC/simple-simcse) so that even ML beginners can train the model from scratch, step by step. It is minimalist code aimed at beginners, but data scientists and ML engineers can also make good use of it.

### Added over Simple-SimCSE
- Added the supervised learning part, which shows you step by step how to construct the training dataset (see the dataset-construction sketch in the Datasets section below).
- Added distributed training logic; if you have a multi-GPU setup, you can train faster (a minimal sketch follows this list).
- Added SageMaker training. `ml.g4dn.xlarge` trains fine, but we recommend `ml.g4dn.12xlarge` or `ml.g5.12xlarge` for faster training.
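
The distributed logic in the actual training code is more involved (for in-batch negatives the embeddings also need to be handled per device); the snippet below is only a generic PyTorch DDP sketch, launched with `torchrun`, to illustrate the multi-GPU setup. Names such as `train_dataset` are placeholders for objects defined elsewhere.

```python
# Generic multi-GPU sketch (launch with: torchrun --nproc_per_node=<num_gpus> train.py)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup_ddp(model, train_dataset, batch_size=64):
    # torchrun sets the LOCAL_RANK / WORLD_SIZE environment variables
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # Each process sees a different shard of the data
    sampler = DistributedSampler(train_dataset)
    loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return model, loader
```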

## Requirements
We recommend preparing an Amazon SageMaker instance with the specifications below to run this hands-on.

### SageMaker Notebook instance
- `ml.g4dn.xlarge`

### SageMaker Training instance
- `ml.g4dn.xlarge` (Minimum)
- `ml.g5.12xlarge` (Recommended)

## Datasets

For supervised learning, you need an NLI dataset that labels the relationship between pairs of sentences. For unsupervised learning, we recommend using raw Wikipedia text split into sentences. This hands-on uses datasets registered on the Hugging Face Hub, but you can also configure your own dataset; a sketch of how supervised triplets can be built follows the lists below.

The datasets used in this hands-on are as follows:

#### Supervised
- [KLUE-NLI](https://huggingface.co/datasets/klue/viewer/nli/)
- [Kor-NLI](https://huggingface.co/datasets/kor_nli)

#### Unsupervised
- [kowiki-sentences](https://huggingface.co/datasets/heegyu/kowiki-sentences): the 2022-10-01 Korean Wikipedia dump, split into sentences with the kss morphological analyzer (backend=mecab).
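
As an illustration of the supervised setup: supervised SimCSE typically turns an NLI dataset into (anchor, positive, hard-negative) triplets, where the premise is the anchor, an entailed hypothesis is the positive, and a contradicted hypothesis is the hard negative. The sketch below shows one way to do this with the KLUE-NLI training split from the Hugging Face Hub; the label names are read from the dataset features rather than hard-coded, and the exact preprocessing in this repository may differ.

```python
from collections import defaultdict
from datasets import load_dataset

# Load the KLUE-NLI training split from the Hugging Face Hub
nli = load_dataset("klue", "nli", split="train")
label_names = nli.features["label"].names  # expected: ["entailment", "neutral", "contradiction"]
ent_id = label_names.index("entailment")
con_id = label_names.index("contradiction")

# Group hypotheses by premise and label
by_premise = defaultdict(lambda: {"entailment": [], "contradiction": []})
for ex in nli:
    if ex["label"] == ent_id:
        by_premise[ex["premise"]]["entailment"].append(ex["hypothesis"])
    elif ex["label"] == con_id:
        by_premise[ex["premise"]]["contradiction"].append(ex["hypothesis"])

# Keep only premises that have both a positive and a hard negative
triplets = [
    (premise, hyps["entailment"][0], hyps["contradiction"][0])
    for premise, hyps in by_premise.items()
    if hyps["entailment"] and hyps["contradiction"]
]
print(len(triplets), triplets[0])
```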

## How to train
- See https://github.com/daekeun-ml/KoSimCSE-SageMaker
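
For reference, a SageMaker training job for a script like this is usually launched with the `HuggingFace` estimator from the SageMaker Python SDK. The snippet below is a minimal sketch, not the exact launcher from the repository above: the entry point, source directory, framework versions, and S3 path are illustrative placeholders and should be taken from the repository.

```python
import sagemaker
from sagemaker.huggingface import HuggingFace

session = sagemaker.Session()
role = sagemaker.get_execution_role()  # assumes you are running inside SageMaker

estimator = HuggingFace(
    entry_point="train.py",            # hypothetical training script name
    source_dir="./src",                # hypothetical source directory
    instance_type="ml.g5.12xlarge",    # recommended instance type
    instance_count=1,
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    role=role,
    hyperparameters={
        "batch_size": 64,
        "lr": 3e-5,
        "temperature": 0.05,
        "max_seq_len": 32,
    },
)

# Train on data previously uploaded to S3 (bucket/prefix are placeholders)
estimator.fit({"training": f"s3://{session.default_bucket()}/kosimcse/train"})
```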

## Performance
We trained with hyperparameters similar to those in the paper and did not perform any tuning. A higher max sequence length does not guarantee higher performance; building a good NLI dataset is more important.

```json
{
    "batch_size": 64,
    "num_epochs": "1 (unsupervised) / 3 (supervised)",
    "lr": 3e-05,
    "num_warmup_steps": 0,
    "temperature": 0.05,
    "lr_scheduler_type": "linear",
    "max_seq_len": 32,
    "use_fp16": "True"
}
```
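
The STS numbers below follow the usual sentence-embedding evaluation protocol: embed both sentences of each pair, compute a similarity (e.g., cosine), and report Pearson/Spearman correlation against the gold scores. The sketch below shows a minimal version of the cosine-Spearman column for KLUE-STS; the gold-score field (`labels["label"]`) is our reading of the KLUE dataset schema and may need adjusting, and the repository's own evaluation script may differ in details.

```python
import torch
import torch.nn.functional as F
from datasets import load_dataset
from scipy.stats import spearmanr

def evaluate_cosine_spearman(model, tokenizer, device="cpu", batch_size=64):
    sts = load_dataset("klue", "sts", split="validation")
    sims, golds = [], []
    model.eval().to(device)
    with torch.no_grad():
        for i in range(0, len(sts), batch_size):
            batch = sts[i:i + batch_size]
            enc1 = tokenizer(batch["sentence1"], padding=True, truncation=True, return_tensors="pt").to(device)
            enc2 = tokenizer(batch["sentence2"], padding=True, truncation=True, return_tensors="pt").to(device)
            emb1, emb2 = model(**enc1), model(**enc2)
            sims.extend(F.cosine_similarity(emb1, emb2).cpu().tolist())
            golds.extend([label["label"] for label in batch["labels"]])
    # Spearman correlation between predicted similarities and gold scores, scaled to 0-100
    return spearmanr(sims, golds).correlation * 100

# score = evaluate_cosine_spearman(model, tokenizer)
# print(f"KLUE-STS cosine Spearman: {score:.2f}")
```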

### KLUE-STS
| Model | Avg | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson | Manhattan Spearman | Dot Pearson | Dot Spearman |
|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| KoSimCSE-RoBERTa-base (Unsupervised) | 81.17 | 81.27 | 80.96 | 81.70 | 80.97 | 81.63 | 80.89 | 81.12 | 80.81 |
| KoSimCSE-RoBERTa-base (Supervised) | 84.19 | 83.04 | 84.46 | 84.97 | 84.50 | 84.95 | 84.45 | 82.88 | 84.28 |
| KoSimCSE-RoBERTa-large (Unsupervised) | 81.96 | 82.09 | 81.71 | 82.45 | 81.73 | 82.42 | 81.69 | 81.98 | 81.58 |
| KoSimCSE-RoBERTa-large (Supervised) | 85.37 | 84.38 | 85.99 | 85.97 | 85.81 | 86.00 | 85.79 | 83.87 | 85.15 |

### Kor-STS
| Model | Avg | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson | Manhattan Spearman | Dot Pearson | Dot Spearman |
|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| KoSimCSE-RoBERTa-base (Unsupervised) | 81.20 | 81.53 | 81.17 | 80.89 | 81.20 | 80.93 | 81.22 | 81.48 | 81.14 |
| KoSimCSE-RoBERTa-base (Supervised) | 85.33 | 85.16 | 85.46 | 85.37 | 85.45 | 85.31 | 85.37 | 85.13 | 85.41 |
| KoSimCSE-RoBERTa-large (Unsupervised) | 81.71 | 82.10 | 81.78 | 81.12 | 81.78 | 81.15 | 81.80 | 82.15 | 81.80 |
| KoSimCSE-RoBERTa-large (Supervised) | 85.54 | 85.41 | 85.78 | 85.18 | 85.51 | 85.26 | 85.61 | 85.70 | 85.90 |

## References
- Simple-SimCSE: https://github.com/hppRC/simple-simcse
- KoSimCSE: https://github.com/BM-K/KoSimCSE-SKT
- SimCSE (official): https://github.com/princeton-nlp/SimCSE
- SimCSE paper: https://aclanthology.org/2021.emnlp-main.552