import time
import re
from math import floor, ceil
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
# from nltk.tokenize import sent_tokenize
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
import webvtt
from io import StringIO
from mosestokenizer import MosesSentenceSplitter
from indicTrans.inference.engine import Model
from punctuate import RestorePuncts
from indicnlp.tokenize.sentence_tokenize import sentence_split

app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'

# Load the translation models and the punctuation-restoration model once at startup.
indic2en_model = Model(expdir='models/v3/indic-en')
en2indic_model = Model(expdir='models/v3/en-indic')
m2m_model = Model(expdir='models/m2m')
rpunct = RestorePuncts()

indic_language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}

splitter = MosesSentenceSplitter('en')


def get_inference_params():
    source_language = request.form['source_language']
    target_language = request.form['target_language']
    if source_language in indic_language_dict and target_language == 'English':
        model = indic2en_model
        source_lang = indic_language_dict[source_language]
        target_lang = 'en'
    elif source_language == 'English' and target_language in indic_language_dict:
        model = en2indic_model
        source_lang = 'en'
        target_lang = indic_language_dict[target_language]
    elif source_language in indic_language_dict and target_language in indic_language_dict:
        model = m2m_model
        source_lang = indic_language_dict[source_language]
        target_lang = indic_language_dict[target_language]
    else:
        # Fail early on unsupported pairs (e.g. English -> English) instead of
        # hitting an UnboundLocalError on the return below.
        raise ValueError('Unsupported language pair: {} -> {}'.format(source_language, target_language))
    return model, source_lang, target_lang


@app.route('/', methods=['GET'])
def main():
    return "IndicTrans API"


@app.route('/supported_languages', methods=['GET'])
@cross_origin()
def supported_languages():
    return jsonify(indic_language_dict)
@app.route("/translate", methods=['POST'])
@cross_origin()
def infer_indic_en():
model, source_lang, target_lang = get_inference_params()
source_text = request.form['text']
start_time = time.time()
target_text = model.translate_paragraph(source_text, source_lang, target_lang)
end_time = time.time()
return {'text':target_text, 'duration':round(end_time-start_time, 2)}
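
# Example request against the /translate endpoint (illustrative only; the host,
# port, and form values below are assumptions, not part of this file):
#
#   curl -X POST http://localhost:5000/translate \
#        -F source_language=Hindi \
#        -F target_language=English \
#        -F "text=<paragraph to translate>"
#
# The response is JSON of the form {"text": "...", "duration": 1.23}.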
@app.route("/translate_vtt", methods=['POST'])
@cross_origin()
def infer_vtt_indic_en():
start_time = time.time()
model, source_lang, target_lang = get_inference_params()
source_text = request.form['text']
# vad_segments = request.form['vad_nochunk'] # Assuming it is an array of start & end timestamps
vad = webvtt.read_buffer(StringIO(source_text))
source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]
## SUMANTH LOGIC HERE ##
# for each vad timestamp, do:
large_sentence = ' '.join(source_sentences) # only sentences in that time range
large_sentence = large_sentence.lower()
# split_sents = sentence_split(large_sentence, 'en')
# print(split_sents)
large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
punctuated = rpunct.punctuate(large_sentence, batch_size=32)
end_time = time.time()
print("Time Taken for punctuation: {} s".format(end_time - start_time))
start_time = time.time()
split_sents = splitter([punctuated]) ### Please uncomment
# print(split_sents)
# output_sentence_punctuated = model.translate_paragraph(punctuated, source_lang, target_lang)
output_sents = model.batch_translate(split_sents, source_lang, target_lang)
# print(output_sents)
# output_sents = split_sents
# print(output_sents)
# align this to those range of source_sentences in `captions`
map_ = {split_sents[i] : output_sents[i] for i in range(len(split_sents))}
# print(map_)
punct_para = ' '.join(list(map_.keys()))
nmt_para = ' '.join(list(map_.values()))
nmt_words = nmt_para.split(' ')
len_punct = len(punct_para.split(' '))
len_nmt = len(nmt_para.split(' '))
start = 0
for i in range(len(vad)):
if vad[i].text == '':
continue
len_caption = len(vad[i].text.split(' '))
frac = (len_caption / len_punct)
# frac = round(frac, 2)
req_nmt_size = floor(frac * len_nmt)
# print(frac, req_nmt_size)
vad[i].text = ' '.join(nmt_words[start:start+req_nmt_size])
# print(vad[i].text)
# print(start, req_nmt_size)
start += req_nmt_size
end_time = time.time()
print("Time Taken for translation: {} s".format(end_time - start_time))
# vad.save('aligned.vtt')
return {
'text': vad.content,
# 'duration':round(end_time-start_time, 2)
}
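

if __name__ == '__main__':
    # Minimal entry point for local testing; this is an assumption, not part of
    # the original deployment (which may launch the app via gunicorn or the
    # Spaces runtime). Host and port are illustrative.
    app.run(host='0.0.0.0', port=5000)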