import time
import re
from math import floor
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
# from nltk.tokenize import sent_tokenize
from flask import Flask, request, jsonify, abort
from flask_cors import CORS, cross_origin
import webvtt
from io import StringIO
from mosestokenizer import MosesSentenceSplitter
from indicTrans.inference.engine import Model
from punctuate import RestorePuncts
from indicnlp.tokenize.sentence_tokenize import sentence_split

app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'

# Translation models: Indic -> English, English -> Indic, and Indic <-> Indic (m2m).
indic2en_model = Model(expdir='models/v3/indic-en')
en2indic_model = Model(expdir='models/v3/en-indic')
m2m_model = Model(expdir='models/m2m')

# Punctuation restoration model, applied before sentence splitting of ASR-style text.
rpunct = RestorePuncts()

indic_language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}

# English sentence splitter used on the punctuated text.
splitter = MosesSentenceSplitter('en')


def get_inference_params():
    """Pick the model and language codes based on the request's form fields."""
    source_language = request.form['source_language']
    target_language = request.form['target_language']

    if source_language in indic_language_dict and target_language == 'English':
        model = indic2en_model
        source_lang = indic_language_dict[source_language]
        target_lang = 'en'
    elif source_language == 'English' and target_language in indic_language_dict:
        model = en2indic_model
        source_lang = 'en'
        target_lang = indic_language_dict[target_language]
    elif source_language in indic_language_dict and target_language in indic_language_dict:
        model = m2m_model
        source_lang = indic_language_dict[source_language]
        target_lang = indic_language_dict[target_language]
    else:
        # Fail fast on unsupported pairs instead of raising UnboundLocalError below.
        abort(400, description='Unsupported language pair: {} -> {}'.format(
            source_language, target_language))

    return model, source_lang, target_lang


@app.route('/', methods=['GET'])
def main():
    return "IndicTrans API"


@app.route('/supported_languages', methods=['GET'])
@cross_origin()
def supported_languages():
    return jsonify(indic_language_dict)


@app.route('/translate', methods=['POST'])
@cross_origin()
def infer_indic_en():
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']

    start_time = time.time()
    target_text = model.translate_paragraph(source_text, source_lang, target_lang)
    end_time = time.time()

    return {'text': target_text, 'duration': round(end_time - start_time, 2)}


@app.route('/translate_vtt', methods=['POST'])
@cross_origin()
def infer_vtt_indic_en():
    start_time = time.time()
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']
    # vad_segments = request.form['vad_nochunk']  # Assuming it is an array of start & end timestamps

    # Parse the WebVTT payload and flatten each caption into a single line of text.
    vad = webvtt.read_buffer(StringIO(source_text))
    source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]

    # Merge all captions, strip punctuation and casing, then restore punctuation so the
    # text can be split into well-formed sentences before translation.
    large_sentence = ' '.join(source_sentences)
    large_sentence = large_sentence.lower()
    large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
    # split_sents = sentence_split(large_sentence, 'en')

    punctuated = rpunct.punctuate(large_sentence, batch_size=32)
    end_time = time.time()
    print("Time Taken for punctuation: {} s".format(end_time - start_time))

    start_time = time.time()
    split_sents = splitter([punctuated])

    # Translate sentence by sentence (alternatively, the whole paragraph at once):
    # output_sentence_punctuated = model.translate_paragraph(punctuated, source_lang, target_lang)
    output_sents = model.batch_translate(split_sents, source_lang, target_lang)

    # Align the translated text back to the original caption time ranges.
    map_ = {split_sents[i]: output_sents[i] for i in range(len(split_sents))}
    punct_para = ' '.join(list(map_.keys()))
    nmt_para = ' '.join(list(map_.values()))
    nmt_words = nmt_para.split(' ')

    len_punct = len(punct_para.split(' '))
    len_nmt = len(nmt_para.split(' '))

    # Distribute translated words across captions in proportion to each caption's
    # share of the source word count, preserving the original timestamps.
    start = 0
    for i in range(len(vad)):
        if vad[i].text == '':
            continue

        len_caption = len(vad[i].text.split(' '))
        frac = len_caption / len_punct
        req_nmt_size = floor(frac * len_nmt)

        vad[i].text = ' '.join(nmt_words[start:start + req_nmt_size])
        start += req_nmt_size

    end_time = time.time()
    print("Time Taken for translation: {} s".format(end_time - start_time))

    # vad.save('aligned.vtt')

    return {
        'text': vad.content,
        # 'duration': round(end_time - start_time, 2)
    }
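

# Optional local entry point (a minimal sketch, not part of the original file).
# The service may instead be launched with `flask run` or a production WSGI
# server; the host and port below are assumptions for local testing only.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)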