# -*- coding: utf-8 -*-

"""
@Author     : Rong Ye
@Time       : May 2022
@Contact    : yerong@bytedance
@Description:
"""

import os
import traceback
import shutil
import yaml
import re
from pydub import AudioSegment
import gradio as gr
from huggingface_hub import snapshot_download


LANGUAGE_CODES = {
    "German": "de",
    "Spanish": "es",
    "French": "fr",
    "Italian": "it",
    "Netherlands": "nl",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
}

LANG_GEN_SETUPS = {
    "de": {"beam": 10, "lenpen": 0.7},
    "es": {"beam": 10, "lenpen": 0.1},
    "fr": {"beam": 10, "lenpen": 1.0},
    "it": {"beam": 10, "lenpen": 0.5},
    "nl": {"beam": 10, "lenpen": 0.4},
    "pt": {"beam": 10, "lenpen": 0.9},
    "ro": {"beam": 10, "lenpen": 1.0},
    "ru": {"beam": 10, "lenpen": 0.3},
}
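
# "beam" (beam-search width) and "lenpen" (length penalty) are per-language
# decoding hyperparameters; get_vocab_and_yaml() below writes them into
# data/config.yaml, which is presumably read by the ConST fork of fairseq at
# generation time.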

os.system("git clone https://github.com/ReneeYe/ConST")
os.system("mv ConST ConST_git")
os.system('mv -n ConST_git/* ./')
os.system("rm -rf ConST_git")
os.system("pip3 install --editable ./")
os.system("mkdir -p data checkpoint")


# Fetch the pretrained ConST checkpoints and vocabularies from the Hugging Face Hub.
huggingface_model_dir = snapshot_download(repo_id="ReneeYe/ConST_en2x_models")
print(huggingface_model_dir)


def convert_audio_to_16k_wav(audio_input):
    sound = AudioSegment.from_file(audio_input)
    sample_rate = sound.frame_rate
    num_channels = sound.channels
    num_frames = int(sound.frame_count())
    filename = audio_input.split("/")[-1]
    print("original file is at:", audio_input)
    if (num_channels > 1) or (sample_rate != 16000): # convert to mono-channel 16k wav
        if num_channels > 1:
            sound = sound.set_channels(1)
        if sample_rate != 16000:
            sound = sound.set_frame_rate(16000)
        num_frames = int(sound.frame_count())
        filename = os.path.splitext(filename)[0] + "_16k.wav"  # also handles non-.wav inputs
        sound.export(f"data/{filename}", format="wav")
    else:
        shutil.copy(audio_input, f'data/{filename}')
    return filename, num_frames
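
# Example usage (hypothetical path, for illustration only):
#   filename, n_frames = convert_audio_to_16k_wav("/tmp/recording.wav")
# This writes data/recording_16k.wav if resampling or downmixing was needed,
# otherwise it copies the original file into data/ unchanged.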


def prepare_tsv(file_name, n_frame, language, task="ST"):
    tgt_lang = LANGUAGE_CODES[language]
    with open("data/test_case.tsv", "w") as f:
        f.write("id\taudio\tn_frames\ttgt_text\tspeaker\tsrc_lang\ttgt_lang\tsrc_text\n")
        f.write(f"sample\t{file_name}\t{n_frame}\tThis is in {tgt_lang}.\tspk.1\ten\t{tgt_lang}\tThis is English.\n")


def get_vocab_and_yaml(language):
    tgt_lang = LANGUAGE_CODES[language]
    # get: spm_ende.model and spm_ende.txt, and save to data/xxx
    # if exist, no need to download
    shutil.copy(os.path.join(huggingface_model_dir, f"vocabulary/spm_en{tgt_lang}.model"), "./data")
    shutil.copy(os.path.join(huggingface_model_dir, f"vocabulary/spm_en{tgt_lang}.txt"), "./data")

    # write yaml file
    abs_path = os.getcwd()
    yaml_dict = LANG_GEN_SETUPS[tgt_lang]
    yaml_dict["input_channels"] = 1
    yaml_dict["use_audio_input"] = True
    yaml_dict["prepend_tgt_lang_tag"] = True
    yaml_dict["prepend_src_lang_tag"] = True
    yaml_dict["audio_root"] = os.path.join(abs_path, "data")
    yaml_dict["vocab_filename"] = f"spm_en{tgt_lang}.txt"
    yaml_dict["bpe_tokenizer"] = {"bpe": "sentencepiece",
                                  "sentencepiece_model": os.path.join(abs_path, f"data/spm_en{tgt_lang}.model")}
    with open("data/config.yaml", "w") as f:
        yaml.dump(yaml_dict, f)
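
# For German, the dumped data/config.yaml looks roughly like (keys sorted
# alphabetically by yaml.dump; paths depend on the working directory):
#   audio_root: /path/to/cwd/data
#   beam: 10
#   bpe_tokenizer:
#     bpe: sentencepiece
#     sentencepiece_model: /path/to/cwd/data/spm_ende.model
#   input_channels: 1
#   lenpen: 0.7
#   prepend_src_lang_tag: true
#   prepend_tgt_lang_tag: true
#   use_audio_input: true
#   vocab_filename: spm_ende.txt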


def get_model(language):
    # checkpoints were already fetched by snapshot_download; return the local path
    return os.path.join(huggingface_model_dir, f"models/const_en{LANGUAGE_CODES[language]}.pt")


def generate(model_path):
    os.system(f"python3 fairseq_cli/generate.py data/ --gen-subset test_case --task speech_to_text --prefix-size 1 \
                --max-tokens 4000000 --max-source-positions 4000000 \
                --config-yaml config.yaml  --path {model_path} | tee temp.txt")
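    # fairseq's generate.py prints each hypothesis as "D-<id>\t<score>\t<text>";
    # keep those lines, sort them numerically by sample id, and take the text field.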
    output = os.popen("grep ^D temp.txt | sort -n -k 2 -t '-' | cut -f 3")
    return output.read().strip()


def post_processing(raw_sentence):
    # Heuristically clean the hypothesis: drop short speaker prefixes such as
    # "XX:" and bracketed background tags such as "(Music)".
    output_sentence = raw_sentence
    if ":" in raw_sentence:
        split_sent = raw_sentence.split(":")
        if len(split_sent) == 2:
            prefix = split_sent[0].strip()
            if len(prefix) <= 3:
                output_sentence = split_sent[1].strip()
            elif ("(" in prefix) and (")" in prefix):
                bgm = re.findall(r"\(.*?\)", prefix)[0]
                if len(prefix.replace(bgm, "").strip()) <= 3:
                    output_sentence = split_sent[1].strip()
                elif len(split_sent[1].strip()) > 8:
                    output_sentence = split_sent[1].strip()

    elif ("(" in raw_sentence) and (")" in raw_sentence):
        bgm_list = re.findall(r"\(.*?\)", raw_sentence)
        for bgm in bgm_list:
            if len(raw_sentence.replace(bgm, "").strip()) > 5:
                output_sentence = output_sentence.replace(bgm, "").strip()
        if len(output_sentence) <= 5:
            output_sentence = raw_sentence
    return output_sentence
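
# Illustrative behaviour (hypothetical hypotheses, for documentation only):
#   post_processing("DJ: Das ist gut.")     -> "Das ist gut."   (short speaker prefix dropped)
#   post_processing("(Musik) Das ist gut.") -> "Das ist gut."   (background tag removed)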


def remove_temp_files(audio_file):
    os.remove("temp.txt")
    os.remove("data/test_case.tsv")
    os.remove(f"data/{audio_file}")


def run(audio_file, language):
    try:
        converted_audio_file, n_frame = convert_audio_to_16k_wav(audio_file)
        prepare_tsv(converted_audio_file, n_frame, language)
        get_vocab_and_yaml(language)
        model_path = get_model(language)
        generated_output = post_processing(generate(model_path))
        remove_temp_files(converted_audio_file)
        return generated_output
    except Exception:
        traceback.print_exc()
        return error_output(language)


def error_output(language):
    return f"Fail to translate the audio into {language}, you may use the examples I provide."


inputs = [
        gr.inputs.Audio(source="microphone", type="filepath", label="Record something (in English)..."),
        gr.inputs.Dropdown(list(LANGUAGE_CODES.keys()), default="German", label="From English to Languages X..."),
    ]

iface = gr.Interface(
    fn=run,
    inputs=inputs,
    outputs=[gr.outputs.Textbox(label="The translation")],
    examples=[['short-case.wav', "German"], ['long-case.wav', "German"]],
    title="ConST: an end-to-end speech translator",
    description='ConST is an end-to-end speech-to-text translation model, whose algorithm corresponds to the '
                'NAACL 2022 paper *"Cross-modal Contrastive Learning for Speech Translation"* (see the paper at https://arxiv.org/abs/2205.02444 for more details). '
                'This is a live demo of ConST, translating English into eight European languages. \n'
                'P.S. For a better experience, we recommend using **Chrome** to record audio.',
    article="- The motivation of the ConST model is to use the contrastive learning method to learn similar representations for semantically similar speech and text, " \
            "thus leveraging MT to help improve ST performance. \n"
            "- The models you are experiencing are trained based on the MuST-C dataset (https://ict.fbk.eu/must-c/), " \
            "which only contains about 250k parallel data at each translation direction. "
            "The translation performance of these language directions varies from 20-30+ BLEU, "
            "so it is normal to find some flaws in the translation, and we are trying to improve the models, "
            "such as training on larger datasets and developing more advanced algorithms.\n"
            "- If you want to know how to train the models, you may refer to https://github.com/ReneeYe/ConST.",
    theme="peach",
)
iface.launch()