#%%
import argparse
import time
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import json
import torch
from dotenv import load_dotenv
#%%


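# Load environment variables (e.g. TRANSFORMERS_CACHE, used for the model cache below).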
load_dotenv()
from nltk.tokenize import sent_tokenize

wd = os.path.dirname(os.path.realpath(__file__))


class BackTranslatorAugmenter:
    """
    A class that performs back-translation for data augmentation.
    For best results we recommend using bottleneck languages (`out_lang`)
    such as Russian (ru) and Spanish (es).

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the language to translate with, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

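        # Load the two Helsinki-NLP opus-mt (MarianMT) translation models:
        # in_lang -> out_lang for the forward pass and out_lang -> in_lang for
        # the back pass. TRANSFORMERS_CACHE (from .env) sets the download cache.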
        self.in_tokenizer = AutoTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        )
        self.in_model = AutoModelForSeq2SeqLM.from_pretrained(
            f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        ).to(self.device)
        self.out_tokenizer = AutoTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        )
        self.out_model = AutoModelForSeq2SeqLM.from_pretrained(
            f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        ).to(self.device)

    def back_translate(self, text, verbose=False):
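        """
        Translate ``text`` into ``out_lang`` and back into ``in_lang``.

        :param text: a string or a list of strings to back-translate
        :param verbose: if True, print the intermediate translation and timing
        :return: the list of back-translated strings
        """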
        if verbose:
            tic = time.time()
        encoded_text = self.in_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_overflowing_tokens=True,
        ).to(self.device)
        # If any input was truncated, fall back to sentence-level back-translation.
        if encoded_text["num_truncated_tokens"].max() > 0:
            print("Text is too long, back-translating sentence by sentence.")
            return self.back_translate_long(text, verbose=verbose)

        in_generated_ids = self.in_model.generate(
            inputs=encoded_text["input_ids"],
            attention_mask=encoded_text["attention_mask"],
        )

        in_preds = [
            self.in_tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in in_generated_ids
        ]
        if verbose:
            print("in_pred : ", in_preds)
        encoded_text = self.out_tokenizer(
            in_preds,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_overflowing_tokens=True,
        ).to(self.device)
        out_generated_ids = self.out_model.generate(
            inputs=encoded_text["input_ids"],
            attention_mask=encoded_text["attention_mask"],
        )
        out_preds = [
            self.out_tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in out_generated_ids
        ]

        if verbose:
            tac = time.time()
            print("out_pred : ", out_preds)
            print("Elapsed time : ", tac - tic)
        return out_preds

    def back_translate_long(self, text, verbose=False):
        # Over-length input: back-translate each document sentence by sentence, then rejoin.
        texts = [text] if isinstance(text, str) else text
        return [" ".join(self.back_translate(sent_tokenize(t), verbose=verbose)) for t in texts]


def do_backtranslation(**args):
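    """
    Read the input CSV, back-translate the columns listed in ``col_map``
    batch by batch, and write the augmented dataset to ``output_data_path``.
    """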
    df = pd.read_csv(args["input_data_path"])
    data_augmenter = BackTranslatorAugmenter(
        in_lang=args["in_lang"], out_lang=args["out_lang"]
    )

    dict_res = {col_name: [] for _, col_name in args["col_map"].items()}

    for i in tqdm(range(0, len(df), args["batch_size"])):
        for old_col, new_col in args["col_map"].items():
            dict_res[new_col] += data_augmenter.back_translate(
                list(df[old_col].iloc[i : i + args["batch_size"]])
            )

    augmented_df = pd.DataFrame(dict_res)
    os.makedirs(os.path.dirname(args["output_data_path"]), exist_ok=True)
    augmented_df.to_csv(args["output_data_path"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Back Translate a dataset for better training"
    )
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help="""the text input language, defaults to "en",
        one can choose between {'es','ru','en','fr','de','pt','zh'}
        but please have a look at https://huggingface.co/Helsinki-NLP to make sure the language
        pair you ask for is available""",
    )

    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help="""the bottleneck language to translate through, defaults to "ru",
        one can choose between {'es','ru','en','fr','de','pt','zh'}
        but please have a look at https://huggingface.co/Helsinki-NLP to make sure the language
        pair you ask for is available""",
    )

    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
        help="dataset location, please note it should be a CSV file containing "
        "the columns given as the keys of --col_map "
        "(by default 'abstract' and 'tldr')",
    )

    parser.add_argument(
        "-output_data_path",
        type=str,
        default=os.path.join(
            wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
        ),
        help="augmented dataset output location",
    )

    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="column names to apply data augmentation on, given as a JSON "
        "key/value mapping such as "
        '{"input_column_name1": "output_column_name1"}; by default '
        "it is set to {'abstract': 'text', 'tldr': 'summary'}; "
        "if you don't want to rename the columns, "
        "provide a mapping where each key equals its value",
    )

    parser.add_argument("-batch_size", type=int, default=25, help="batch_size")

    args = parser.parse_args()
    do_backtranslation(**vars(args))
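
# Example invocation (the script file name below is assumed; adjust it to your
# checkout), back-translating through Spanish in batches of 8:
#
#   python backtranslation.py -in_lang en -out_lang es \
#       -input_data_path dataset/train_neurips_dataset.csv \
#       -output_data_path dataset/augmented_datas/augmented_dataset_output.csv \
#       --col_map '{"abstract": "text", "tldr": "summary"}' \
#       -batch_size 8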