import argparse
import random
from pathlib import Path
from datasets import load_dataset, Dataset, DatasetDict

MAX_WORDS = 80

def clean_text(text: str) -> str:
    """
    Normalize a sentence: strip surrounding whitespace, then any
    leading/trailing dash or colon markers left over from the corpus.
    """
    return text.strip().strip("-:").strip()
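
# Illustrative behaviour of clean_text (examples only, not executed):
#   clean_text("- Bonjour le monde :")  -> "Bonjour le monde"
#   clean_text("  Hello world  ")       -> "Hello world"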

def _process_and_format_split(
    dataset_split: Dataset,
    max_length_diff: int,
    num_samples: int | None = None,
    bidirectional: bool = True,
) -> Dataset | None:
    """
    Processes a single split of the OPUS-100 dataset into anchor-positive pairs.
    If bidirectional is True, the reversed FR->EN pair is also added for each
    EN->FR pair, which helps bilingual training.
    """
    opus_pairs = []
    
    examples_to_process = dataset_split
    if num_samples:
        if num_samples > len(dataset_split):
            print(f"Warning: Requested {num_samples} samples, but split only has {len(dataset_split)}. Using all available samples.")
        examples_to_process = dataset_split.select(range(min(num_samples, len(dataset_split))))

    for example in examples_to_process:
        eng_sentence = example.get("translation", {}).get("en")
        fra_sentence = example.get("translation", {}).get("fr")

        if isinstance(eng_sentence, str) and isinstance(fra_sentence, str) and eng_sentence and fra_sentence:
            eng_sentence = clean_text(eng_sentence)
            fra_sentence = clean_text(fra_sentence)

            # Skip pairs where cleaning emptied either sentence
            if not eng_sentence or not fra_sentence:
                continue

            # Skip instances where both sentences are the same
            if eng_sentence == fra_sentence:
                continue

            # Skip if word count difference is too large
            len_en = len(eng_sentence.split())
            len_fr = len(fra_sentence.split())
            if abs(len_en - len_fr) > max_length_diff:
                continue

            # Skip if any of the two sentences is above MAX_WORDS
            if len_en > MAX_WORDS or len_fr > MAX_WORDS:
                continue

            # Add EN->FR pair
            opus_pairs.append([eng_sentence, fra_sentence])
            
            # Add FR->EN pair for bidirectional training
            if bidirectional:
                opus_pairs.append([fra_sentence, eng_sentence])

    if not opus_pairs:
        return None
    
    # Shuffle to mix EN->FR and FR->EN pairs
    random.shuffle(opus_pairs)
    return Dataset.from_dict({
        "anchor": [pair[0] for pair in opus_pairs],
        "positive": [pair[1] for pair in opus_pairs],
    })
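
# Each row of the returned Dataset is one parallel pair, e.g. (values
# illustrative): {"anchor": "Hello world.", "positive": "Bonjour le monde."}.
# With bidirectional=True the reversed row also appears:
# {"anchor": "Bonjour le monde.", "positive": "Hello world."}.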

def prepare_opus100_data(
    num_pairs: int = 1_000_000,
    output_dir: str = "data/en-fr-opus",
    max_length_diff: int = 7,
    bidirectional: bool = True,
) -> None:
    """
    Downloads and prepares the OPUS-100 English-French dataset for training.

    Fetches the 'en-fr' train, validation, and test splits from the 'Helsinki-NLP/opus-100' dataset, 
    processes them into the expected format for contrastive training with sentence-transformers 
    (Dataset with 'anchor' and 'positive' columns, where each row contains a pair of translated 
    sentences), and saves it to a local dir as a DatasetDict.

    Args:
        num_pairs (int): The number of source translation pairs to select from the original
                         training split; the validation and test splits are used in their entirety.
                         With bidirectional=True, the resulting train set holds up to 2x this many
                         rows (filtering may remove some pairs).
        output_dir (str): The directory where the processed dataset will be saved.
        max_length_diff (int): The maximum allowed difference in word count between anchor and positive.
        bidirectional (bool): Whether to include both EN->FR and FR->EN pairs.
    """

    print(f"Loading dataset from the hub ('Helsinki-NLP/opus-100', 'en-fr' split)...")
    try:
        full_dataset = load_dataset("Helsinki-NLP/opus-100", "en-fr")
        print(f"Successfully loaded dataset with splits: {list(full_dataset.keys())}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    dataset_dict = DatasetDict()

    print(f"\nProcessing train split, selecting up to {num_pairs} source pairs...")
    if bidirectional:
        print("Creating bidirectional pairs (EN->FR and FR->EN)...")
    
    train_dataset = _process_and_format_split(
        full_dataset["train"], 
        max_length_diff, 
        num_samples=num_pairs,
        bidirectional=bidirectional
    )
    
    if train_dataset:
        dataset_dict["train"] = train_dataset
        print(f"Created train set with {len(train_dataset)} total pairs.")
    else:
        print("Could not create a train set. Exiting.")
        return

    print("\nProcessing validation split...")
    validation_dataset = _process_and_format_split(full_dataset["validation"], max_length_diff)
    if validation_dataset:
        dataset_dict["validation"] = validation_dataset
        print(f"Created validation set with {len(validation_dataset)} pairs.")
    else:
        print("Validation set could not be created or is empty.")

    print("\nProcessing test split...")
    test_dataset = _process_and_format_split(full_dataset["test"], max_length_diff)
    if test_dataset:
        dataset_dict["test"] = test_dataset
        print(f"Created test set with {len(test_dataset)} pairs.")
    else:
        print("Test set could not be created or is empty.")
    
    print("\nFinal Dataset Structure:")
    print(dataset_dict)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    print(f"\nSaving processed dataset to '{output_path}'...")
    dataset_dict.save_to_disk(output_path)
    
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare OPUS-100 en-fr dataset for SPLADE training."
    )
    parser.add_argument(
        "--num_pairs",
        type=int,
        default=1_000_000,
        help="Number of sentence pairs to process from the dataset.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data/en-fr-opus",
        help="Directory to save the processed dataset.",
    )
    parser.add_argument(
        "--max_length_diff",
        type=int,
        default=4,
        help="Maximum allowed difference in word count between anchor and positive. Pairs with a larger difference are excluded.",
    )
    parser.add_argument(
        "--bidirectional",
        action="store_true",
        help="Include both EN->FR and FR->EN pairs (if omitted, EN->FR only).",
    )
    args = parser.parse_args()

    prepare_opus100_data(
        num_pairs=args.num_pairs,
        output_dir=args.output_dir,
        max_length_diff=args.max_length_diff,
        bidirectional=args.bidirectional,
    )

""" To clean:
Sample 282148:
  anchor:   "
  positive: ".
"""