tranhoangnguyen03 commited on
Commit
6886972
·
verified ·
1 Parent(s): 6700fe6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -2
README.md CHANGED
@@ -1,6 +1,101 @@
1
  ---
2
  license: apache-2.0
3
  ---
4
- TBD
5
 
6
- Credit: giangvo.gt@gmail.com
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
 
4
 
5
+ ## Generate training data
6
+ ```
7
+ # Function to convert dataframe to list of InputExample
8
+ def df_to_input_examples(df):
9
+ return [
10
+ InputExample(texts=[row['query'],
11
+ row['document']],
12
+ label=float(row['relevance_score']))
13
+ for _, row in df.iterrows()
14
+ ]
15
+
16
+ train_samples = df_to_input_examples(train_df)
17
+ val_samples = df_to_input_examples(val_df)
18
+
19
+ # Create a DataLoader for training
20
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=16)
21
+ ```
22
+
23
+ ## Create Evaluator class
24
+ ```
25
+ # Custom evaluator for CrossEncoder
26
+ class CrossEncoderEvaluator:
27
+ def __init__(self, eval_samples):
28
+ self.eval_samples = eval_samples
29
+
30
+ def __call__(self, model, **kwargs): # Add **kwargs to catch extra arguments
31
+ predictions = model.predict([[sample.texts[0], sample.texts[1]] for sample in self.eval_samples])
32
+ labels = [sample.label for sample in self.eval_samples]
33
+
34
+ pearson_corr, _ = pearsonr(predictions, labels)
35
+ spearman_corr, _ = spearmanr(predictions, labels)
36
+
37
+ return (pearson_corr + spearman_corr) / 2 # Average of Pearson and Spearman correlations
38
+
39
+ # Prepare the evaluator
40
+ evaluator = CrossEncoderEvaluator(val_samples)
41
+ ```
42
+
43
+ ## Train the model
44
+ ```
45
+ # Initialize the cross-encoder model
46
+ model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', num_labels=1)
47
+
48
+ # Train the model
49
+ model.fit(
50
+ train_dataloader=train_dataloader,
51
+ evaluator=evaluator,
52
+ epochs=100,
53
+ warmup_steps=100,
54
+ evaluation_steps=500,
55
+ output_path='fine_tuned_reranker'
56
+ )
57
+ ```
58
+
59
+ ## Usage
60
+ ```
61
+ # Load the fine-tuned reranker
62
+ reranker_model = CrossEncoder('fine_tuned_reranker')
63
+
64
+ def search_and_rerank(query, documents, top_k=10):
65
+ # Prepare pairs for reranking
66
+ pairs = [(query, doc) for doc in documents]
67
+
68
+ # Rerank using fine-tuned cross-encoder
69
+ rerank_scores = reranker_model.predict(pairs)
70
+
71
+ # Sort results by reranker scores
72
+ reranked_results = sorted(
73
+ zip(documents, rerank_scores.tolist()),
74
+ key=lambda x: x[1], reverse=True
75
+ )
76
+
77
+ return reranked_results
78
+
79
+ query = "OPPO 8GB 128G"
80
+ documents = [
81
+ "OPPO Reno11F 5G 8GB-256GB",
82
+ "OPPO Reno11F 5G 8GB-32GB",
83
+ "OPPO Reno11F 5G 16GB-128GB",
84
+ "Samsung galaxy 128GB",
85
+ "Samsung S24 128GB",
86
+ # ...
87
+ ]
88
+
89
+ start_time = time.time()
90
+ results = search_and_rerank(query, documents, len(documents)-1)
91
+ end_time = time.time()
92
+
93
+ execution_time = (end_time - start_time)*1000
94
+ print(f"Execution time: {execution_time:.4f} mili seconds")
95
+
96
+ print(f"Query: \t\t\t\t{query}")
97
+ for res in results:
98
+ print(f"Score: {res[-1]:.4f} | Document: {res[0]}")
99
+ ```
100
+
101
+ Credit goes to: giangvo.gt@gmail.com