Shan1023 committed on
Commit 08453bd • 1 Parent(s): 49c9017

Upload 2 files

Files changed (2)
  1. README.md +188 -1
  2. gitattributes +36 -0
README.md CHANGED
@@ -1,3 +1,190 @@
  ---
- license: apache-2.0
+ license: mit
+ language:
+ - ko
+ - en
+ pipeline_tag: text-classification
  ---

# Korean Reranker Training on Amazon SageMaker

### This is a fine-tuning guide for developing a **Korean Reranker**.
ko-reranker is a model fine-tuned on Korean data, based on [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large). <br>
See [korean-reranker-git](https://github.com/aws-samples/aws-ai-ml-workshop-kr/tree/master/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr) for more details.

- - -

## 0. Features
- #### <span style="#FF69B4;"> Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding.</span>
- #### <span style="#FF69B4;"> Feeding a question and a passage to the reranker yields a relevance score.</span>
- #### <span style="#FF69B4;"> The reranker is optimized with a cross-entropy loss, so the relevance score is not bounded to a specific range (see the sketch below).</span>

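Because the output is an unbounded logit rather than a cosine similarity, it is common to squash it only when a bounded value is needed, for example with a sigmoid per pair or a softmax across one query's candidates (the `exp_normalize` helper in the next section is the latter). A tiny sketch with made-up scores:

```python
import numpy as np

raw_scores = np.array([4.2, -1.3, 0.7])      # made-up reranker logits for three pairs

sigmoid = 1 / (1 + np.exp(-raw_scores))      # per-pair score squashed into (0, 1)
softmax = np.exp(raw_scores - raw_scores.max())
softmax = softmax / softmax.sum()            # relative weights across one query's candidates

print(sigmoid, softmax)
```
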
## 1. Usage

- using Transformers
```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def exp_normalize(x):
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

model_path = "Dongjin-kr/ko-reranker"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

pairs = [["나는 너를 싫어해", "나는 너를 사랑해"],
         ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]

with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    scores = exp_normalize(scores.numpy())
    print(f'first: {scores[0]}, second: {scores[1]}')
```

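In a RAG pipeline the same call is typically used to re-order the top-k passages returned by a retriever for a single query. Below is a minimal sketch of that pattern; the query and passages are made-up examples, and the model ID is the same one used in the SageMaker section.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "Dongjin-kr/ko-reranker"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

query = "대한민국의 수도는 어디인가?"
passages = [
    "서울은 대한민국의 수도이다.",
    "부산은 대한민국 제2의 도시이다.",
    "도쿄는 일본의 수도이다.",
]

with torch.no_grad():
    batch = tokenizer([[query, p] for p in passages], padding=True, truncation=True,
                      return_tensors='pt', max_length=512)
    scores = model(**batch, return_dict=True).logits.view(-1).float()

# Higher score means more relevant; sort candidates by descending score.
for passage, score in sorted(zip(passages, scores.tolist()), key=lambda x: x[1], reverse=True):
    print(f"{score:8.3f}  {passage}")
```
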
- using SageMaker
```python
import json

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'Dongjin-kr/ko-reranker',
    'HF_TASK': 'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role,
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,    # number of instances
    instance_type='ml.g5.large'  # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
            {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

# deserialization (JSON response)
out = json.loads(response['Body'].read().decode())
print(f'Response: {out}')
```

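If you are invoking the endpoint from the same session that deployed it, the `predictor` object returned by `deploy()` already handles JSON serialization, so the raw `boto3` call above is not required; a minimal sketch reusing the same payload shape:

```python
# Equivalent call through the SageMaker Python SDK predictor object.
result = predictor.predict({
    "inputs": [
        {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
        {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
    ]
})
print(result)
```
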
## 2. Background
- #### <span style="#FF69B4;"> **Context order affects accuracy** ([Lost in the Middle, *Liu et al., 2023*](https://arxiv.org/pdf/2307.03172.pdf)) </span>

- #### <span style="#FF69B4;"> [Why you need a reranker](https://www.pinecone.io/learn/series/rag/rerankers/)</span>
    - With current LLMs, stuffing in more context does not make answers better; the relevant passages need to be near the top for the model to answer correctly.
    - The similarity (relevance) score used in semantic search is not that precise. (Is a higher-ranked passage really always more relevant to the question than a lower-ranked one?)
        * Embeddings are specialized for capturing the meaning behind a document.
        * A question and its answer are not semantically identical. ([Hypothetical Document Embeddings](https://medium.com/prompt-engineering/hyde-revolutionising-search-with-hypothetical-document-embeddings-3474df795af8))
        * There is a penalty for using ANNs ([Approximate Nearest Neighbors](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)).

- - -

## 3. Reranker models

- #### <span style="#FF69B4;"> [Cohere] [Reranker](https://txt.cohere.com/rerank/)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)</span>
- #### <span style="#FF69B4;"> [BAAI] [bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)</span>

- - -

## 4. Dataset

- #### <span style="#FF69B4;"> [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking) </span>
    - (Question, Answer, Negative)-Triplets from the MS MARCO Passages dataset, 499,184 samples
    - The dataset is in English.
    - It was translated with Amazon Translate before use.

- #### <span style="#FF69B4;"> Format </span>
```
{"query": str, "pos": List[str], "neg": List[str]}
```
- `query` is the question, `pos` is a list of positive passages, and `neg` is a list of negative passages. If a query has no negative passages, you can randomly sample some from the whole corpus to use as negatives (a minimal sketch follows the example below).

- #### <span style="#FF69B4;"> Example </span>
```
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
```

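As a concrete illustration of the note above, here is a minimal sketch of turning (query, positive) pairs into training records, sampling random negatives from the rest of the corpus when none are provided; the variable names and corpus contents are hypothetical.

```python
import json
import random

# Hypothetical inputs: (query, positive passage) pairs and a passage corpus.
qa_pairs = [
    ("대한민국의 수도는?", "미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."),
]
corpus = [
    "미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다.",
    "지구는 태양 주위를 공전한다.",
    "커피는 에티오피아가 원산지로 알려져 있다.",
]

random.seed(42)
with open("train.jsonl", "w", encoding="utf-8") as f:
    for query, pos in qa_pairs:
        # Sample negatives from passages other than the positive one.
        candidates = [p for p in corpus if p != pos]
        negs = random.sample(candidates, k=min(2, len(candidates)))
        record = {"query": query, "pos": [pos], "neg": negs}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```
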
- - -

## 5. Performance
| Model | has-right-in-contexts | mrr (mean reciprocal rank) |
|:---------------------------|:-----------------:|:--------------------------:|
| without-reranker (default) | 0.93 | 0.80 |
| with-reranker (bge-reranker-large) | 0.95 | 0.84 |
| **with-reranker (fine-tuned on Korean data)** | **0.96** | **0.87** |

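For reference, a minimal sketch of how these two metrics can be computed, assuming each evaluation query records the 1-based rank of the gold passage among the retrieved contexts (or `None` when it is missing); this input format is illustrative, not the actual schema of the evaluation file.

```python
from typing import List, Optional

def evaluate(gold_ranks: List[Optional[int]]) -> None:
    """has-right-in-contexts: share of queries whose gold passage was retrieved at all.
    mrr: mean of 1/rank, counting misses as 0."""
    n = len(gold_ranks)
    has_right = sum(r is not None for r in gold_ranks) / n
    mrr = sum(1.0 / r for r in gold_ranks if r is not None) / n
    print(f"has-right-in-contexts: {has_right:.2f}, mrr: {mrr:.2f}")

# Made-up example: gold passage ranks for five queries (None = not retrieved).
evaluate([1, 2, None, 1, 3])
```
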
- **evaluation set**:
```code
./dataset/evaluation/eval_dataset.csv
```
- **training parameters** (the group-wise loss implied by `train_group_size` is sketched below):

```json
{
    "learning_rate": 5e-6,
    "fp16": true,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01
}
```

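With `train_group_size` set to 3, each training step scores one query against a group of three passages (the positive first, then sampled negatives) and applies a cross-entropy loss over the group with the positive as the target class; this mirrors the FlagEmbedding-style reranker objective the training code is based on. A minimal, self-contained sketch of that loss, not the actual training script:

```python
import torch
import torch.nn.functional as F

def group_cross_entropy(logits: torch.Tensor, group_size: int) -> torch.Tensor:
    """logits: relevance scores for a flat batch of (query, passage) pairs,
    laid out in groups of `group_size` with the positive passage first."""
    grouped = logits.view(-1, group_size)                     # (num_queries, group_size)
    target = torch.zeros(grouped.size(0), dtype=torch.long)   # positive is index 0
    return F.cross_entropy(grouped, target)

# Example: 2 queries x group of 3 (1 positive + 2 negatives) = 6 scores.
scores = torch.tensor([2.3, -0.4, 0.1, 1.7, 0.9, -1.2])
print(group_cross_entropy(scores, group_size=3))
```
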
- - -

## 6. Acknowledgement
- <span style="#FF69B4;"> Part of the code is developed based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file) and [KoSimCSE-SageMaker](https://github.com/daekeun-ml/KoSimCSE-SageMaker/tree/7de6eefef8f1a646c664d0888319d17480a3ebe5).</span>

- - -

## 7. Citation
- <span style="#FF69B4;"> If you find this repository useful, please consider giving it a like ⭐ and a citation.</span>

- - -

## 8. Contributors
- <span style="#FF69B4;"> **Dongjin Jang, Ph.D.** (AWS AI/ML Specialist Solutions Architect) | [Mail](mailto:dongjinj@amazon.com) | [LinkedIn](https://www.linkedin.com/in/dongjin-jang-kr/) | [Git](https://github.com/dongjin-ml) | </span>

- - -

## 9. License
- <span style="#FF69B4;"> FlagEmbedding is licensed under the [MIT License](https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/LICENSE). </span>

gitattributes ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
tokenizer.json filter=lfs diff=lfs merge=lfs -text