Upload 2 files
Browse files- README.md +188 -1
- gitattributes +36 -0
README.md
CHANGED
@@ -1,3 +1,190 @@
|
|
1 |
---
|
2 |
-
license:
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
license: mit
|
3 |
+
language:
|
4 |
+
- ko
|
5 |
+
- en
|
6 |
+
pipeline_tag: text-classification
|
7 |
---
|
8 |
+
|
9 |
+
# Korean Reranker Training on Amazon SageMaker
|
10 |
+
|
11 |
+
### **ํ๊ตญ์ด Reranker** ๊ฐ๋ฐ์ ์ํ ํ์ธํ๋ ๊ฐ์ด๋๋ฅผ ์ ์ํฉ๋๋ค.
|
12 |
+
ko-reranker๋ [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) ๊ธฐ๋ฐ ํ๊ตญ์ด ๋ฐ์ดํฐ์ ๋ํ fine-tuned model ์
๋๋ค. <br>
|
13 |
+
๋ณด๋ค ์์ธํ ์ฌํญ์ [korean-reranker-git](https://github.com/aws-samples/aws-ai-ml-workshop-kr/tree/master/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr)์ ์ฐธ๊ณ ํ์ธ์
|
14 |
+
|
15 |
+
- - -
|
16 |
+
|
17 |
+
## 0. Features
|
18 |
+
- #### <span style="#FF69B4;"> Reranker๋ ์๋ฒ ๋ฉ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ง๋ฌธ๊ณผ ๋ฌธ์๋ฅผ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ฉฐ ์๋ฒ ๋ฉ ๋์ ์ ์ฌ๋๋ฅผ ์ง์ ์ถ๋ ฅํฉ๋๋ค.</span>
|
19 |
+
- #### <span style="#FF69B4;"> Reranker์ ์ง๋ฌธ๊ณผ ๊ตฌ์ ์ ์
๋ ฅํ๋ฉด ์ฐ๊ด์ฑ ์ ์๋ฅผ ์ป์ ์ ์์ต๋๋ค.</span>
|
20 |
+
- #### <span style="#FF69B4;"> Reranker๋ CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ต์ ํ๋๋ฏ๋ก ๊ด๋ จ์ฑ ์ ์๊ฐ ํน์ ๋ฒ์์ ๊ตญํ๋์ง ์์ต๋๋ค.</span>
|
21 |
+
|
22 |
+
## 1. Usage
|
23 |
+
|
24 |
+
- using Transformers
|
25 |
+
```
|
26 |
+
def exp_normalize(x):
|
27 |
+
b = x.max()
|
28 |
+
y = np.exp(x - b)
|
29 |
+
return y / y.sum()
|
30 |
+
|
31 |
+
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
32 |
+
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
34 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
35 |
+
model.eval()
|
36 |
+
|
37 |
+
pairs = [["๋๋ ๋๋ฅผ ์ซ์ดํด", "๋๋ ๋๋ฅผ ์ฌ๋ํด"], \
|
38 |
+
["๋๋ ๋๋ฅผ ์ข์ํด", "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"]]
|
39 |
+
|
40 |
+
with torch.no_grad():
|
41 |
+
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
|
42 |
+
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
|
43 |
+
scores = exp_normalize(scores.numpy())
|
44 |
+
print (f'first: {scores[0]}, second: {scores[1]}')
|
45 |
+
```
|
46 |
+
|
47 |
+
- using SageMaker
|
48 |
+
```
|
49 |
+
import sagemaker
|
50 |
+
import boto3
import json
|
51 |
+
from sagemaker.huggingface import HuggingFaceModel
|
52 |
+
|
53 |
+
try:
|
54 |
+
role = sagemaker.get_execution_role()
|
55 |
+
except ValueError:
|
56 |
+
iam = boto3.client('iam')
|
57 |
+
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
|
58 |
+
|
59 |
+
# Hub Model configuration. https://huggingface.co/models
|
60 |
+
hub = {
|
61 |
+
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
|
62 |
+
'HF_TASK':'text-classification'
|
63 |
+
}
|
64 |
+
|
65 |
+
# create Hugging Face Model Class
|
66 |
+
huggingface_model = HuggingFaceModel(
|
67 |
+
transformers_version='4.28.1',
|
68 |
+
pytorch_version='2.0.0',
|
69 |
+
py_version='py310',
|
70 |
+
env=hub,
|
71 |
+
role=role,
|
72 |
+
)
|
73 |
+
|
74 |
+
# deploy model to SageMaker Inference
|
75 |
+
predictor = huggingface_model.deploy(
|
76 |
+
initial_instance_count=1, # number of instances
|
77 |
+
instance_type='ml.g5.large' # ec2 instance type
|
78 |
+
)
|
79 |
+
|
80 |
+
runtime_client = boto3.Session().client('sagemaker-runtime')
|
81 |
+
payload = json.dumps(
|
82 |
+
{
|
83 |
+
"inputs": [
|
84 |
+
{"text": "๋๋ ๋๋ฅผ ์ซ์ดํด", "text_pair": "๋๋ ๋๋ฅผ ์ฌ๋ํด"},
|
85 |
+
{"text": "๋๋ ๋๋ฅผ ์ข์ํด", "text_pair": "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"}
|
86 |
+
]
|
87 |
+
}
|
88 |
+
)
|
89 |
+
|
90 |
+
response = runtime_client.invoke_endpoint(
|
91 |
+
EndpointName="<endpoint-name>",
|
92 |
+
ContentType="application/json",
|
93 |
+
Accept="application/json",
|
94 |
+
Body=payload
|
95 |
+
)
|
96 |
+
|
97 |
+
## deserialization
|
98 |
+
out = json.loads(response['Body'].read().decode()) ## for json
|
99 |
+
print (f'Response: {out}')
|
100 |
+
|
101 |
+
```
|
102 |
+
|
103 |
+
## 2. Background
|
104 |
+
- #### <span style="color:#FF69B4;"> **์ปจํ์คํธ ์์๊ฐ ์ ํ๋์ ์ํฅ ์ค๋ค**([Lost in the Middle, *Liu et al., 2023*](https://arxiv.org/pdf/2307.03172.pdf)) </span>
|
105 |
+
|
106 |
+
- #### <span style="#FF69B4;"> [Reranker ์ฌ์ฉํด์ผ ํ๋ ์ด์ ](https://www.pinecone.io/learn/series/rag/rerankers/)</span>
|
107 |
+
- ํ์ฌ LLM์ context ๋ง์ด ๋ฃ๋๋ค๊ณ ์ข์๊ฑฐ ์๋, relevantํ๊ฒ ์์์ ์์ด์ผ ์ ๋ต์ ์ ๋งํด์ค๋ค
|
108 |
+
- Semantic search์์ ์ฌ์ฉํ๋ similarity(relevant) score๊ฐ ์ ๊ตํ์ง ์๋ค. (์ฆ, ์์ ๋ญ์ปค๋ฉด ํ์ ๋ญ์ปค๋ณด๋ค ํญ์ ๋ ์ง๋ฌธ์ ์ ์ฌํ ์ ๋ณด๊ฐ ๋ง์?)
|
109 |
+
* Embedding์ meaning behind document๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ํนํ๋์ด ์๋ค.
|
110 |
+
* ์ง๋ฌธ๊ณผ ์ ๋ต์ด ์๋ฏธ์ ๊ฐ์๊ฑด ์๋๋ค. ([Hypothetical Document Embeddings](https://medium.com/prompt-engineering/hyde-revolutionising-search-with-hypothetical-document-embeddings-3474df795af8))
|
111 |
+
* ANNs([Approximate Nearest Neighbors](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)) ์ฌ์ฉ์ ๋ฐ๋ฅธ ํจ๋ํฐ
|
112 |
+
|
113 |
+
- - -
|
114 |
+
|
115 |
+
## 3. Reranker models
|
116 |
+
|
117 |
+
- #### <span style="#FF69B4;"> [Cohere] [Reranker](https://txt.cohere.com/rerank/)</span>
|
118 |
+
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)</span>
|
119 |
+
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)</span>
|
120 |
+
|
121 |
+
- - -
|
122 |
+
|
123 |
+
## 4. Dataset
|
124 |
+
|
125 |
+
- #### <span style="#FF69B4;"> [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking) </span>
|
126 |
+
- (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
|
127 |
+
- ํด๋น ๋ฐ์ดํฐ ์
์ ์๋ฌธ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
|
128 |
+
- Amazon Translate ๊ธฐ๋ฐ์ผ๋ก ๋ฒ์ญํ์ฌ ํ์ฉํ์์ต๋๋ค.
|
129 |
+
|
130 |
+
- #### <span style="#FF69B4;"> Format </span>
|
131 |
+
```
|
132 |
+
{"query": str, "pos": List[str], "neg": List[str]}
|
133 |
+
```
|
134 |
+
- Query๋ ์ง๋ฌธ์ด๊ณ , pos๋ ๊ธ์ ํ
์คํธ ๋ชฉ๋ก, neg๋ ๋ถ์ ํ
์คํธ ๋ชฉ๋ก์
๋๋ค. ์ฟผ๋ฆฌ์ ๋ํ ๋ถ์ ํ
์คํธ๊ฐ ์๋ ๊ฒฝ์ฐ ์ ์ฒด ๋ง๋ญ์น์์ ์ผ๋ถ๋ฅผ ๋ฌด์์๋ก ์ถ์ถํ์ฌ ๋ถ์ ํ
์คํธ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค.
|
135 |
+
|
136 |
+
- #### <span style="#FF69B4;"> Example </span>
|
137 |
+
```
|
138 |
+
{"query": "๋ํ๋ฏผ๊ตญ์ ์๋๋?", "pos": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋๊ต์ด๋ฉฐ ํ๊ตญ์ ์์ธ์ด๋ค."], "neg": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋๊ต์ด๋ฉฐ ๋ถํ์ ํ์์ด๋ค."]}
|
139 |
+
```
|
140 |
+
|
141 |
+
- - -
|
142 |
+
|
143 |
+
## 5. Performance
|
144 |
+
| Model | has-right-in-contexts | mrr (mean reciprocal rank) |
|
145 |
+
|:---------------------------|:-----------------:|:--------------------------:|
|
146 |
+
| without-reranker (default)| 0.93 | 0.80 |
|
147 |
+
| with-reranker (bge-reranker-large)| 0.95 | 0.84 |
|
148 |
+
| **with-reranker (fine-tuned using korean)** | **0.96** | **0.87** |
|
149 |
+
|
150 |
+
- **evaluation set**:
|
151 |
+
```code
|
152 |
+
./dataset/evaluation/eval_dataset.csv
|
153 |
+
```
|
154 |
+
- **training parameters**:
|
155 |
+
|
156 |
+
```python
|
157 |
+
{
|
158 |
+
"learning_rate": 5e-6,
|
159 |
+
"fp16": True,
|
160 |
+
"num_train_epochs": 3,
|
161 |
+
"per_device_train_batch_size": 1,
|
162 |
+
"gradient_accumulation_steps": 32,
|
163 |
+
"train_group_size": 3,
|
164 |
+
"max_len": 512,
|
165 |
+
"weight_decay": 0.01,
|
166 |
+
}
|
167 |
+
```
|
168 |
+
|
169 |
+
- - -
|
170 |
+
|
171 |
+
## 6. Acknowledgement
|
172 |
+
- <span style="#FF69B4;"> Part of the code is developed based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file) and [KoSimCSE-SageMaker](https://github.com/daekeun-ml/KoSimCSE-SageMaker/tree/7de6eefef8f1a646c664d0888319d17480a3ebe5).</span>
|
173 |
+
|
174 |
+
- - -
|
175 |
+
|
176 |
+
## 7. Citation
|
177 |
+
- <span style="color:#FF69B4;"> If you find this repository useful, please consider giving a like โญ and citation</span>
|
178 |
+
|
179 |
+
- - -
|
180 |
+
|
181 |
+
## 8. Contributors:
|
182 |
+
- <span style="color:#FF69B4;"> **Dongjin Jang, Ph.D.** (AWS AI/ML Specialist Solutions Architect) | [Mail](mailto:dongjinj@amazon.com) | [Linkedin](https://www.linkedin.com/in/dongjin-jang-kr/) | [Git](https://github.com/dongjin-ml) | </span>
|
183 |
+
|
184 |
+
- - -
|
185 |
+
|
186 |
+
## 9. License
|
187 |
+
- <span style="color:#FF69B4;"> FlagEmbedding is licensed under the [MIT License](https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/LICENSE). </span>
|
188 |
+
|
189 |
+
|
190 |
+
|
gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|