Spaces:
Sleeping
Sleeping
import os, sys | |
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartForConditionalGeneration | |
# import torch | |
import gradio as gr | |
import requests | |
import json | |
# from huggingface_hub import login | |
class LTRC_Translation_API():
    """Small HTTP client for the LTRC translation endpoint.

    Two-letter language codes supplied by callers are mapped to the
    three-letter codes that are placed in the request payload.
    """

    def __init__(self, url='https://ssmt.iiit.ac.in/onemt', src_lang='en', tgt_lang='te', timeout=60):
        """Store the endpoint and resolve the language pair.

        Args:
            url: Endpoint that receives the translation POST request.
            src_lang: Two-letter source language code (a key of ``lang_map``).
            tgt_lang: Two-letter target language code (a key of ``lang_map``).
            timeout: Seconds to wait for the server before giving up, so a
                stalled endpoint cannot block the caller indefinitely.
        """
        self.lang_map = {'te': 'tel', 'en': 'eng', 'ta': 'tam', 'ml': 'mal', 'mr': 'mar', 'kn': 'kan', 'hi': 'hin'}
        self.url = url
        self.headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }
        self.timeout = timeout
        # Unknown (or None) codes fall back to English -> Telugu. The fallback
        # must be the *mapped* three-letter code, otherwise the request would
        # carry 'en'/'te' while every other request carries 'eng'/'tel'.
        self.src_lang = self.lang_map.get(src_lang, self.lang_map['en'])
        self.tgt_lang = self.lang_map.get(tgt_lang, self.lang_map['te'])

    def translate(self, text):
        """Translate ``text`` from ``self.src_lang`` to ``self.tgt_lang``.

        Args:
            text: The text to translate.

        Returns:
            The value of the ``data`` field of the JSON response, or ``''``
            when the input is empty/blank, the field is missing, or the
            request fails for any reason (the failure is printed).
        """
        # Nothing to translate: skip the network round trip entirely.
        if not text or not text.strip():
            return ''
        try:
            data = {'text': text, 'source_language': self.src_lang, 'target_language': self.tgt_lang}
            response = requests.post(self.url, headers=self.headers, json=data, timeout=self.timeout)
            # Surface HTTP 4xx/5xx as a clear error instead of a JSON parse failure.
            response.raise_for_status()
            return response.json().get('data', '')
        except Exception as e:
            # Best-effort by design: the UI shows an empty result on failure.
            print("Exception: ", e)
            return ''
# class Headline_Generation(): | |
# def __init__(self, model_name = "lokeshmadasu42/sample"): | |
# self.model_name = model_name | |
# self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
# self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True) | |
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
# self.model.to(self.device) | |
# self.model.eval() | |
# self.bos_id = self.tokenizer._convert_token_to_id_with_added_voc("<s>") | |
# self.eos_id = self.tokenizer._convert_token_to_id_with_added_voc("</s>") | |
# self.pad_id = self.tokenizer._convert_token_to_id_with_added_voc("<pad>") | |
# self.lang_map = {'as': '<2as>', 'bn': '<2bn>', 'en': '<2en>', 'gu': '<2gu>', 'hi': '<2hi>', 'kn': '<2kn>', 'ml': '<2ml>', 'mr': '<2mr>', 'or': '<2or>', 'pa': '<2pa>', 'ta': '<2ta>', 'te': '<2te>'} | |
# print("Headline Generation model loaded...!") | |
# def get_headline(self, text, lang_id): | |
# inp = self.tokenizer(text, add_special_tokens=False, return_tensors="pt", padding=True).to(self.device) | |
# inp = inp['input_ids'] | |
# lang_code = self.lang_map.get(lang_id, '') | |
# text = text + "</s> " + lang_code | |
# # print("Text: ", text) | |
# model_output = self.model.generate( | |
# inp, | |
# use_cache=True, | |
# num_beams=5, | |
# max_length=32, | |
# min_length=1, | |
# early_stopping=True, | |
# pad_token_id = self.pad_id, | |
# bos_token_id = self.bos_id, | |
# eos_token_id = self.eos_id, | |
# decoder_start_token_id = self.tokenizer._convert_token_to_id_with_added_voc(lang_code) | |
# ) | |
# decoded_output = self.tokenizer.decode( | |
# model_output[0], | |
# skip_special_tokens=True, | |
# clean_up_tokenization_spaces=False | |
# ) | |
# return decoded_output | |
# class Summarization(): | |
# def __init__(self, model_name = "ashokurlana/mBART-TeSum"): | |
# self.model_name = model_name | |
# self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |
# self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
# self.model.to(self.device) | |
# self.model.eval() | |
# self.lang_map = {'te': 'te_IN', 'en': 'en_XX'} | |
# print("Summarization model loaded...!") | |
# def get_summary(self, text, lang_id): | |
# inp = self.tokenizer([text], add_special_tokens=False, return_tensors="pt", max_length = 1024).to(self.device) | |
# inp = inp['input_ids'] | |
# lang_code = self.lang_map.get(lang_id, '') | |
# model_output = self.model.generate( | |
# inp, | |
# use_cache=True, | |
# num_beams=5, | |
# max_length=256, | |
# early_stopping=True | |
# ) | |
# decoded_output = [self.tokenizer.decode( | |
# summ_id, | |
# skip_special_tokens=True, | |
# clean_up_tokenization_spaces=False | |
# ) for summ_id in model_output] | |
# return " ".join(decoded_output) | |
def get_prediction(text, src_lang_id, tgt_lang_id, translate=False):
    """Translate ``text`` between the two selected languages.

    Args:
        text: Source text entered by the user.
        src_lang_id: Two-letter source language code.
        tgt_lang_id: Two-letter target language code.
        translate: Unused; kept only so existing callers that pass it
            keep working.

    Returns:
        The translated text, or ``''`` when ``text`` is empty/blank or the
        translator returns nothing.
    """
    # NOTE: access tokens must never be written into this file, not even in
    # commented-out code; read them from the environment / Space secrets.
    # Any token that was previously committed here should be revoked.

    # Nothing to translate: avoid building a client and calling the service.
    if not text or not text.strip():
        return ''
    translator = LTRC_Translation_API(src_lang=src_lang_id, tgt_lang=tgt_lang_id)
    return translator.translate(text)
# Gradio front end: one text box for the source text, one dropdown each for
# the source and target language codes, and one text box for the result.
# Components are created in the same order as they are wired below.
_source_text = gr.Textbox(lines=8, label="Source Text", info="Provide the news article text here")
_source_language = gr.Dropdown(
    ['en', 'hi', 'kn', 'ml', 'mr', 'ta', 'te'],
    label="Source Language code",
    info="select the source language code",
)
_target_language = gr.Dropdown(
    ['en', 'hi', 'kn', 'ml', 'mr', 'ta', 'te'],
    label="Target Language code",
    info="select the target language code",
)
_translation_output = gr.Textbox(lines=8, label="Translation", info="Translated text")

# The inputs are passed to get_prediction positionally, in this order:
# (text, src_lang_id, tgt_lang_id).
interface = gr.Interface(
    fn=get_prediction,
    inputs=[_source_text, _source_language, _target_language],
    outputs=[_translation_output],
    title="Indic Translation Demo",
)

# Start the web app; share=True asks Gradio for a public share link.
interface.launch(share=True)