File size: 7,370 Bytes
d95e9bb
 
ea5de0e
 
 
 
d95e9bb
ea5de0e
ca3d45e
ea5de0e
 
ece7b2c
010f184
ea5de0e
 
 
2e58d77
ea5de0e
 
 
 
2e58d77
 
01d6325
1b6aaab
2e58d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b6aaab
2e58d77
01d6325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e58d77
ea5de0e
 
 
 
 
 
 
 
 
 
 
2e58d77
ea5de0e
 
 
 
 
 
 
2e58d77
ea5de0e
 
 
 
 
100e6f4
 
 
 
 
 
 
 
 
 
 
ea5de0e
 
 
2e58d77
ea5de0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e58d77
ea5de0e
 
 
 
2e58d77
6f11226
ea5de0e
 
 
2e58d77
ea5de0e
 
 
 
2e58d77
ea5de0e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
---
license: mit
language:
- ko
- en
pipeline_tag: text-classification
---

# Korean Reranker Training on Amazon SageMaker

### **ํ•œ๊ตญ์–ด Reranker** ๊ฐœ๋ฐœ์„ ์œ„ํ•œ ํŒŒ์ธํŠœ๋‹ ๊ฐ€์ด๋“œ๋ฅผ ์ œ์‹œํ•ฉ๋‹ˆ๋‹ค.
ko-reranker๋Š” [BAAI/bge-reranker-larger](https://huggingface.co/BAAI/bge-reranker-large) ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ fine-tuned model ์ž…๋‹ˆ๋‹ค. <br>
๋ณด๋‹ค ์ž์„ธํ•œ ์‚ฌํ•ญ์€ [korean-reranker-git](https://github.com/aws-samples/aws-ai-ml-workshop-kr/tree/master/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr)์„ ์ฐธ๊ณ ํ•˜์„ธ์š”

- - -

## 0. Features
- #### <span style="#FF69B4;"> Reranker๋Š” ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ๊ณผ ๋‹ฌ๋ฆฌ ์งˆ๋ฌธ๊ณผ ๋ฌธ์„œ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•˜๋ฉฐ ์ž„๋ฒ ๋”ฉ ๋Œ€์‹  ์œ ์‚ฌ๋„๋ฅผ ์ง์ ‘ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.</span>
- #### <span style="#FF69B4;"> Reranker์— ์งˆ๋ฌธ๊ณผ ๊ตฌ์ ˆ์„ ์ž…๋ ฅํ•˜๋ฉด ์—ฐ๊ด€์„ฑ ์ ์ˆ˜๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.</span>
- #### <span style="#FF69B4;"> Reranker๋Š” CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์ตœ์ ํ™”๋˜๋ฏ€๋กœ ๊ด€๋ จ์„ฑ ์ ์ˆ˜๊ฐ€ ํŠน์ • ๋ฒ”์œ„์— ๊ตญํ•œ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.</span>

## 1.Usage

- using Transformers
```
    def exp_normalize(x):
      b = x.max()
      y = np.exp(x - b)
      return y / y.sum()
    
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()

    pairs = [["๋‚˜๋Š” ๋„ˆ๋ฅผ ์‹ซ์–ดํ•ด", "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‚ฌ๋ž‘ํ•ด"], \
             ["๋‚˜๋Š” ๋„ˆ๋ฅผ ์ข‹์•„ํ•ด", "๋„ˆ์— ๋Œ€ํ•œ ๋‚˜์˜ ๊ฐ์ •์€ ์‚ฌ๋ž‘ ์ผ ์ˆ˜๋„ ์žˆ์–ด"]]

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        scores = exp_normalize(scores.numpy())
        print (f'first: {scores[0]}, second: {scores[1]}')
```

- using SageMaker
```
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
	role = sagemaker.get_execution_role()
except ValueError:
	iam = boto3.client('iam')
	role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
	'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
	'HF_TASK':'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
	transformers_version='4.28.1',
	pytorch_version='2.0.0',
	py_version='py310',
	env=hub,
	role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
	initial_instance_count=1, # number of instances
	instance_type='ml.g5.large' # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‹ซ์–ดํ•ด", "text_pair": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์‚ฌ๋ž‘ํ•ด"},
            {"text": "๋‚˜๋Š” ๋„ˆ๋ฅผ ์ข‹์•„ํ•ด", "text_pair": "๋„ˆ์— ๋Œ€ํ•œ ๋‚˜์˜ ๊ฐ์ •์€ ์‚ฌ๋ž‘ ์ผ ์ˆ˜๋„ ์žˆ์–ด"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept=application/json",
    Body=payload
)

## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')

```

## 2. Backgound
- #### <span style="#FF69B4;"> **์ปจํƒ์ŠคํŠธ ์ˆœ์„œ๊ฐ€ ์ •ํ™•๋„์— ์˜ํ–ฅ ์ค€๋‹ค**([Lost in Middel, *Liu et al., 2023*](https://arxiv.org/pdf/2307.03172.pdf)) </span>

- #### <span style="#FF69B4;"> [Reranker ์‚ฌ์šฉํ•ด์•ผ ํ•˜๋Š” ์ด์œ ](https://www.pinecone.io/learn/series/rag/rerankers/)</span>
    - ํ˜„์žฌ LLM์€ context ๋งŽ์ด ๋„ฃ๋Š”๋‹ค๊ณ  ์ข‹์€๊ฑฐ ์•„๋‹˜, relevantํ•œ๊ฒŒ ์ƒ์œ„์— ์žˆ์–ด์•ผ ์ •๋‹ต์„ ์ž˜ ๋งํ•ด์ค€๋‹ค
    - Semantic search์—์„œ ์‚ฌ์šฉํ•˜๋Š” similarity(relevant) score๊ฐ€ ์ •๊ตํ•˜์ง€ ์•Š๋‹ค. (์ฆ‰, ์ƒ์œ„ ๋žญ์ปค๋ฉด ํ•˜์œ„ ๋žญ์ปค๋ณด๋‹ค ํ•ญ์ƒ ๋” ์งˆ๋ฌธ์— ์œ ์‚ฌํ•œ ์ •๋ณด๊ฐ€ ๋งž์•„?) 
        * Embedding์€ meaning behind document๋ฅผ ๊ฐ€์ง€๋Š” ๊ฒƒ์— ํŠนํ™”๋˜์–ด ์žˆ๋‹ค. 
        * ์งˆ๋ฌธ๊ณผ ์ •๋‹ต์ด ์˜๋ฏธ์ƒ ๊ฐ™์€๊ฑด ์•„๋‹ˆ๋‹ค. ([Hypothetical Document Embeddings](https://medium.com/prompt-engineering/hyde-revolutionising-search-with-hypothetical-document-embeddings-3474df795af8))
        * ANNs([Approximate Nearest Neighbors](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)) ์‚ฌ์šฉ์— ๋”ฐ๋ฅธ ํŒจ๋„ํ‹ฐ

- - -

## 3. Reranker models

- #### <span style="#FF69B4;"> [Cohere] [Reranker](https://txt.cohere.com/rerank/)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)</span>

- - -

## 4. Dataset

- #### <span style="#FF69B4;"> [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking) </span>
    - (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
    - ํ•ด๋‹น ๋ฐ์ดํ„ฐ ์…‹์€ ์˜๋ฌธ์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
    - Amazon Translate ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฒˆ์—ญํ•˜์—ฌ ํ™œ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค.
 
- #### <span style="#FF69B4;"> Format </span>
```
{"query": str, "pos": List[str], "neg": List[str]}
```
- Query๋Š” ์งˆ๋ฌธ์ด๊ณ , pos๋Š” ๊ธ์ • ํ…์ŠคํŠธ ๋ชฉ๋ก, neg๋Š” ๋ถ€์ • ํ…์ŠคํŠธ ๋ชฉ๋ก์ž…๋‹ˆ๋‹ค. ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ๋ถ€์ • ํ…์ŠคํŠธ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ์ „์ฒด ๋ง๋ญ‰์น˜์—์„œ ์ผ๋ถ€๋ฅผ ๋ฌด์ž‘์œ„๋กœ ์ถ”์ถœํ•˜์—ฌ ๋ถ€์ • ํ…์ŠคํŠธ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

- #### <span style="#FF69B4;"> Examples </span>
```
{"query": "๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„๋Š”?", "pos": ["๋ฏธ๊ตญ์˜ ์ˆ˜๋„๋Š” ์›Œ์‹ฑํ„ด์ด๊ณ , ์ผ๋ณธ์€ ๋„๊ต์ด๋ฉฐ ํ•œ๊ตญ์€ ์„œ์šธ์ด๋‹ค."], "neg": ["๋ฏธ๊ตญ์˜ ์ˆ˜๋„๋Š” ์›Œ์‹ฑํ„ด์ด๊ณ , ์ผ๋ณธ์€ ๋„๊ต์ด๋ฉฐ ๋ถํ•œ์€ ํ‰์–‘์ด๋‹ค."]}
```
    
- - -

## 5. Performance
| Model                     | has-right-in-contexts | mrr (mean reciprocal rank) |
|:---------------------------|:-----------------:|:--------------------------:|
| without-reranker (default)| 0.93 | 0.80 |
| with-reranker (bge-reranker-large)| 0.95 | 0.84 |
| **with-reranker (fine-tuned using korean)** | **0.96** | **0.87** |

- **evaluation set**:
```code
./dataset/evaluation/eval_dataset.csv
```
- **training parameters**: 

```json
{
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01,
}
```

- - -

## 6. Acknowledgement
- <span style="#FF69B4;"> Part of the code is developed based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file) and [KoSimCSE-SageMaker](https://github.com/daekeun-ml/KoSimCSE-SageMaker/tree/7de6eefef8f1a646c664d0888319d17480a3ebe5).</span>

- - -

## 7. Citation
- <span style="#FF69B4;"> If you find this repository useful, please consider giving a like โญ and citation</span>

- - -

## 8. Contributors:
- <span style="#FF69B4;"> **Dongjin Jang, Ph.D.** (AWS AI/ML Specislist Solutions Architect) | [Mail](mailto:dongjinj@amazon.com) | [Linkedin](https://www.linkedin.com/in/dongjin-jang-kr/) | [Git](https://github.com/dongjin-ml) | </span>

- - -

## 9. License
- <span style="#FF69B4;"> FlagEmbedding is licensed under the [MIT License](https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/LICENSE). </span>