import json
import logging
import os
import sys

import gradio as gr
import requests

# NOTE: earlier revisions of this file carried commented-out headline-generation
# and summarization models (transformers / torch / huggingface_hub) plus a CLI
# entry point. That dead code was removed because it embedded a Hugging Face
# access token in plain text -- treat that token as leaked and revoke it.
# Recover the model code from version history if it is needed again.

logger = logging.getLogger(__name__)


class LTRC_Translation_API:
    """Thin client for the LTRC translation endpoint.

    Callers pass two-letter language codes ('en', 'te', ...); they are mapped
    through ``self.lang_map`` to the three-letter codes ('eng', 'tel', ...)
    that are sent in the request payload.
    """

    def __init__(self, url='https://ssmt.iiit.ac.in/onemt', src_lang='en', tgt_lang='te', *, timeout=60):
        """Configure the endpoint and the language pair.

        Args:
            url: Translation endpoint that receives the JSON POST.
            src_lang: Source language, two-letter code (a three-letter code
                already present in ``lang_map``'s values is accepted as-is).
                Unknown or missing codes fall back to English.
            tgt_lang: Target language, same rules; falls back to Telugu.
            timeout: Seconds passed to ``requests.post`` so a stalled
                endpoint cannot block the caller indefinitely.
        """
        self.lang_map = {'te': 'tel', 'en': 'eng', 'ta': 'tam', 'ml': 'mal', 'mr': 'mar', 'kn': 'kan', 'hi': 'hin'}
        self.url = url
        self.headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }
        self.timeout = timeout
        # The fallback must itself be a mapped three-letter code; falling back
        # to the raw 'en'/'te' would send a code in the wrong format. A None
        # here is what an unselected gr.Dropdown produces.
        self.src_lang = self._to_api_code(src_lang, default='en')
        self.tgt_lang = self._to_api_code(tgt_lang, default='te')

    def _to_api_code(self, lang, default):
        """Map *lang* to its three-letter API code, using *default*'s code if unknown."""
        if lang in self.lang_map:
            return self.lang_map[lang]
        if lang in self.lang_map.values():
            # Already an API code; pass it through instead of discarding it.
            return lang
        return self.lang_map[default]

    def translate(self, text):
        """Translate *text* from ``self.src_lang`` to ``self.tgt_lang``.

        Returns:
            The ``data`` field of the endpoint's JSON response ('' if the field
            is absent). Returns '' without calling the endpoint for blank
            input, and '' when the request fails, the HTTP status is an error,
            or the body is not a JSON object -- failures are logged rather
            than raised so the UI shows an empty result instead of crashing.
        """
        if not text or not text.strip():
            # Nothing to translate; skip the network round-trip.
            return ''
        data = {'text': text, 'source_language': self.src_lang, 'target_language': self.tgt_lang}
        try:
            response = requests.post(self.url, headers=self.headers, json=data, timeout=self.timeout)
            response.raise_for_status()
            # response.json() raises a ValueError subclass on a non-JSON body.
            payload = response.json()
        except (requests.RequestException, ValueError):
            logger.exception("Translation request to %s failed", self.url)
            return ''
        if not isinstance(payload, dict):
            logger.error("Unexpected translation response from %s: %r", self.url, payload)
            return ''
        return payload.get('data', '')


def get_prediction(text, src_lang_id, tgt_lang_id, translate=False):
    """Gradio callback: translate *text* from *src_lang_id* to *tgt_lang_id*.

    ``translate`` is unused. It is kept so existing callers that pass it keep
    working; it only gated translation in the removed headline/summary
    pipeline.
    """
    translator = LTRC_Translation_API(src_lang=src_lang_id, tgt_lang=tgt_lang_id)
    return translator.translate(text)


# Two-letter codes offered in the UI; these are exactly the keys of
# LTRC_Translation_API's lang_map.
LANGUAGE_CODES = ['en', 'hi', 'kn', 'ml', 'mr', 'ta', 'te']

interface = gr.Interface(
    get_prediction,
    inputs=[
        gr.Textbox(lines=8, label="Source Text", info="Provide the news article text here"),
        # Defaults mirror LTRC_Translation_API's src_lang/tgt_lang defaults.
        gr.Dropdown(
            LANGUAGE_CODES, value='en', label="Source Language code", info="select the source language code"
        ),
        gr.Dropdown(
            LANGUAGE_CODES, value='te', label="Target Language code", info="select the target language code"
        ),
    ],
    outputs=[
        gr.Textbox(lines=8, label="Translation", info="Translated text"),
    ],
)

interface.launch(share=True)