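"""Flask API server for the IndicTrans translation models.

Serves plain-text translation (/translate) and WebVTT subtitle
translation (/translate_vtt) between English and eleven Indic languages.
"""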
import re
import time
from io import StringIO
from math import floor

import webvtt
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
from mosestokenizer import MosesSentenceSplitter

from indicTrans.inference.engine import Model
from punctuate import RestorePuncts

app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'

# Load the pretrained translation models once at startup.
indic2en_model = Model(expdir='models/v3/indic-en')  # Indic -> English
en2indic_model = Model(expdir='models/v3/en-indic')  # English -> Indic
m2m_model = Model(expdir='models/m2m')               # Indic -> Indic (many-to-many)

# Punctuation restoration model: ASR-style input arrives lowercased and
# unpunctuated, so punctuation must be restored before sentence splitting.
rpunct = RestorePuncts()

# Display name -> ISO 639-1 code for the supported Indic languages.
indic_language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}

# Moses sentence splitter for English, used on the re-punctuated text.
splitter = MosesSentenceSplitter('en')

def get_inference_params():
    """Select the model and language codes for the requested direction."""
    source_language = request.form['source_language']
    target_language = request.form['target_language']

    if source_language in indic_language_dict and target_language == 'English':
        model = indic2en_model
        source_lang = indic_language_dict[source_language]
        target_lang = 'en'
    elif source_language == 'English' and target_language in indic_language_dict:
        model = en2indic_model
        source_lang = 'en'
        target_lang = indic_language_dict[target_language]
    elif source_language in indic_language_dict and target_language in indic_language_dict:
        model = m2m_model
        source_lang = indic_language_dict[source_language]
        target_lang = indic_language_dict[target_language]
    else:
        # Without this branch the function raised UnboundLocalError for
        # unsupported pairs; fail with an explicit error instead.
        raise ValueError(
            'Unsupported language pair: {} -> {}'.format(source_language, target_language)
        )

    return model, source_lang, target_lang

@app.route('/', methods=['GET'])
def main():
    return "IndicTrans API"

@app.route('/supported_languages', methods=['GET'])
@cross_origin()
def supported_languages():
    return jsonify(indic_language_dict)

@app.route("/translate", methods=['POST'])
@cross_origin()
def infer_indic_en():
    """Translate a plain-text paragraph between the requested languages."""
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']

    start_time = time.time()
    target_text = model.translate_paragraph(source_text, source_lang, target_lang)
    end_time = time.time()
    return {'text': target_text, 'duration': round(end_time - start_time, 2)}
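
# Example request (host and port are assumptions for the Flask dev server):
#   curl -X POST http://localhost:5000/translate \
#        -F source_language=Hindi -F target_language=English -F text='...'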

@app.route("/translate_vtt", methods=['POST'])
@cross_origin()
def infer_vtt_indic_en():
    """Translate a WebVTT subtitle file while preserving caption timestamps."""
    start_time = time.time()
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']

    # Parse the VTT payload and flatten each caption to a single line.
    vad = webvtt.read_buffer(StringIO(source_text))
    source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]

    # Join all captions, normalise to lowercase text without punctuation,
    # then restore punctuation so the text can be split into sentences.
    large_sentence = ' '.join(source_sentences)
    large_sentence = large_sentence.lower()
    large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
    punctuated = rpunct.punctuate(large_sentence, batch_size=32)
    end_time = time.time()
    print('Time taken for punctuation: {:.2f} s'.format(end_time - start_time))

    # Split into sentences and translate them in one batch.
    start_time = time.time()
    split_sents = splitter([punctuated])
    output_sents = model.batch_translate(split_sents, source_lang, target_lang)

    # Word counts of the punctuated source and the translated output are
    # used below to redistribute translated words across the captions.
    # Joining the lists directly avoids the misalignment that the previous
    # dict-based mapping caused when split_sents contained duplicates.
    punct_para = ' '.join(split_sents)
    nmt_para = ' '.join(output_sents)
    nmt_words = nmt_para.split(' ')

    len_punct = len(punct_para.split(' '))
    len_nmt = len(nmt_para.split(' '))

    # Redistribute translated words across the original captions in
    # proportion to each caption's share of the source words.
    start = 0
    non_empty = [i for i in range(len(vad)) if vad[i].text != '']
    for pos, i in enumerate(non_empty):
        if pos == len(non_empty) - 1:
            # Give the last caption all remaining words, so that the
            # floor() rounding below never drops trailing words.
            vad[i].text = ' '.join(nmt_words[start:])
            break

        len_caption = len(vad[i].text.split(' '))
        frac = len_caption / len_punct
        req_nmt_size = floor(frac * len_nmt)

        vad[i].text = ' '.join(nmt_words[start:start + req_nmt_size])
        start += req_nmt_size

    end_time = time.time()
    print('Time taken for translation: {:.2f} s'.format(end_time - start_time))

    # Return the translated VTT document as text.
    return {
        'text': vad.content,
    }
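

if __name__ == '__main__':
    # Local development entry point (an assumption; the service may instead
    # be launched via `flask run` or a WSGI server such as gunicorn).
    app.run(host='0.0.0.0', port=5000)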