pavanhitloop's picture
Update app.py
db7d8be
raw
history blame contribute delete
No virus
7.16 kB
import os, sys
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartForConditionalGeneration
# import torch
import gradio as gr
import requests
import json
# from huggingface_hub import login
class LTRC_Translation_API():
    """Thin client for the LTRC machine-translation HTTP endpoint.

    Callers pass two-letter language codes ('en', 'te', ...); the constructor
    maps them to the three-letter codes in ``lang_map`` ('eng', 'tel', ...)
    which are what ``translate`` sends to the service. Unknown source codes
    fall back to English, unknown target codes fall back to Telugu.
    """

    def __init__(self, url = 'https://ssmt.iiit.ac.in/onemt', src_lang = 'en', tgt_lang = 'te', timeout = 60):
        """Configure the endpoint and the language pair.

        Args:
            url: Endpoint that receives the JSON translation request.
            src_lang: Two-letter source-language code (key of ``lang_map``).
            tgt_lang: Two-letter target-language code (key of ``lang_map``).
            timeout: Seconds passed to ``requests.post``; without a timeout
                a stalled server would block the caller indefinitely.
        """
        self.lang_map = {'te': 'tel', 'en': 'eng', 'ta': 'tam', 'ml': 'mal', 'mr': 'mar', 'kn': 'kan', 'hi': 'hin'}
        self.url = url
        self.timeout = timeout
        self.headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        # Fall back to the *mapped* defaults. The previous fallbacks were the
        # raw two-letter codes 'en'/'te', inconsistent with every other value
        # sent to the service (presumably it only understands the 3-letter form).
        self.src_lang = self.lang_map.get(src_lang, self.lang_map['en'])
        self.tgt_lang = self.lang_map.get(tgt_lang, self.lang_map['te'])

    def translate(self, text):
        """Translate ``text`` from ``self.src_lang`` to ``self.tgt_lang``.

        Returns:
            The ``'data'`` field of the service's JSON response (``''`` if the
            field is absent), or ``''`` when ``text`` is empty, the request
            fails, or the response is not a JSON object. Errors are printed,
            not raised, so a UI caller always gets a result back.
        """
        if not text:
            # Nothing to translate; skip the network round-trip.
            return ''
        data = {'text': text, 'source_language': self.src_lang, 'target_language': self.tgt_lang}
        try:
            response = requests.post(self.url, headers = self.headers, json = data, timeout = self.timeout)
            response.raise_for_status()
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException covers connection errors, timeouts and HTTP
            # error statuses; ValueError covers a body that is not valid JSON.
            print("Exception: ", e)
            return ''
        if not isinstance(payload, dict):
            print("Exception: ", "unexpected response payload:", payload)
            return ''
        return payload.get('data', '')
# class Headline_Generation():
# def __init__(self, model_name = "lokeshmadasu42/sample"):
# self.model_name = model_name
# self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# self.model.to(self.device)
# self.model.eval()
# self.bos_id = self.tokenizer._convert_token_to_id_with_added_voc("<s>")
# self.eos_id = self.tokenizer._convert_token_to_id_with_added_voc("</s>")
# self.pad_id = self.tokenizer._convert_token_to_id_with_added_voc("<pad>")
# self.lang_map = {'as': '<2as>', 'bn': '<2bn>', 'en': '<2en>', 'gu': '<2gu>', 'hi': '<2hi>', 'kn': '<2kn>', 'ml': '<2ml>', 'mr': '<2mr>', 'or': '<2or>', 'pa': '<2pa>', 'ta': '<2ta>', 'te': '<2te>'}
# print("Headline Generation model loaded...!")
# def get_headline(self, text, lang_id):
# inp = self.tokenizer(text, add_special_tokens=False, return_tensors="pt", padding=True).to(self.device)
# inp = inp['input_ids']
# lang_code = self.lang_map.get(lang_id, '')
# text = text + "</s> " + lang_code
# # print("Text: ", text)
# model_output = self.model.generate(
# inp,
# use_cache=True,
# num_beams=5,
# max_length=32,
# min_length=1,
# early_stopping=True,
# pad_token_id = self.pad_id,
# bos_token_id = self.bos_id,
# eos_token_id = self.eos_id,
# decoder_start_token_id = self.tokenizer._convert_token_to_id_with_added_voc(lang_code)
# )
# decoded_output = self.tokenizer.decode(
# model_output[0],
# skip_special_tokens=True,
# clean_up_tokenization_spaces=False
# )
# return decoded_output
# class Summarization():
# def __init__(self, model_name = "ashokurlana/mBART-TeSum"):
# self.model_name = model_name
# self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# self.model.to(self.device)
# self.model.eval()
# self.lang_map = {'te': 'te_IN', 'en': 'en_XX'}
# print("Summarization model loaded...!")
# def get_summary(self, text, lang_id):
# inp = self.tokenizer([text], add_special_tokens=False, return_tensors="pt", max_length = 1024).to(self.device)
# inp = inp['input_ids']
# lang_code = self.lang_map.get(lang_id, '')
# model_output = self.model.generate(
# inp,
# use_cache=True,
# num_beams=5,
# max_length=256,
# early_stopping=True
# )
# decoded_output = [self.tokenizer.decode(
# summ_id,
# skip_special_tokens=True,
# clean_up_tokenization_spaces=False
# ) for summ_id in model_output]
# return " ".join(decoded_output)
def get_prediction(text, src_lang_id, tgt_lang_id, translate = False):
    """Translate ``text`` for the Gradio interface.

    Args:
        text: Source text entered by the user.
        src_lang_id: Two-letter source-language code selected in the UI.
        tgt_lang_id: Two-letter target-language code selected in the UI.
        translate: Unused. Kept so existing callers that pass it keep working;
            it belonged to an earlier headline/summary pipeline.

    Returns:
        Whatever ``LTRC_Translation_API.translate`` returns for ``text``.
    """
    # Credentials (e.g. Hugging Face tokens) must never be committed to
    # source; read them from the environment if a future feature needs one.
    translator = LTRC_Translation_API(src_lang = src_lang_id, tgt_lang = tgt_lang_id)
    return translator.translate(text)
# Gradio UI: a source text box plus source/target language dropdowns in,
# translated text out. The choices are the two-letter codes that
# LTRC_Translation_API.lang_map can map.
language_codes = ['en', 'hi', 'kn', 'ml', 'mr', 'ta', 'te']
interface = gr.Interface(
    get_prediction,
    inputs=[
        gr.Textbox(lines = 8, label = "Source Text", info = "Provide the text to translate here"),
        # Pre-select en -> te (the translator's own defaults) so the demo
        # works without touching the dropdowns; an unset dropdown presumably
        # submits None, which would silently hit the fallback language.
        gr.Dropdown(
            language_codes, value = 'en', label="Source Language code", info="select the source language code"
        ),
        gr.Dropdown(
            language_codes, value = 'te', label="Target Language code", info="select the target language code"
        ),
    ],
    outputs=[
        gr.Textbox(lines = 8, label = "Translation", info = "Translated text"),
    ],
    title = "Indic Translation Demo"
)
interface.launch(share=True)