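# Source: Hugging Face file view of app.py (author: Abhishek-D7)
# Last commit: "Update app.py" (e31cf98, verified)
# File-viewer details: raw · history · blame · "No virus" scan · 4.35 kB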
import streamlit as st
from transformers import AutoTokenizer, BartForConditionalGeneration, MarianMTModel, MarianTokenizer
@st.cache_resource
def load_summarization_model():
    """Load the BART-large-CNN summarization model and its tokenizer.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every widget interaction; without caching, this loader
    (called at module level) would re-instantiate the model on every click.

    Returns:
        tuple: ``(model, tokenizer)`` for ``facebook/bart-large-cnn``.
    """
    model_name = 'facebook/bart-large-cnn'
    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
@st.cache_resource
def load_translation_models():
    """Load one MarianMT model and tokenizer per direct language pair.

    Cached with ``st.cache_resource`` so the six models are loaded once and
    reused across Streamlit reruns, instead of being re-instantiated on every
    widget interaction.

    Returns:
        tuple: ``(models, tokenizers)``, two dicts keyed by pair strings of
        the form ``'<source>-<target>'`` (e.g. ``'en-fr'``). Only pairs with
        English on one side are loaded.
    """
    language_pairs = {
        'en-fr': 'Helsinki-NLP/opus-mt-en-fr',
        'fr-en': 'Helsinki-NLP/opus-mt-fr-en',
        'en-de': 'Helsinki-NLP/opus-mt-en-de',
        'de-en': 'Helsinki-NLP/opus-mt-de-en',
        'en-hi': 'Helsinki-NLP/opus-mt-en-hi',
        'hi-en': 'Helsinki-NLP/opus-mt-hi-en',
    }
    models = {}
    tokenizers = {}
    for pair, model_name in language_pairs.items():
        models[pair] = MarianMTModel.from_pretrained(model_name)
        tokenizers[pair] = MarianTokenizer.from_pretrained(model_name)
    return models, tokenizers
# Load every model when the script executes. summarize_text() and
# translate_text() read these module-level globals directly.
summarization_model, summarization_tokenizer = load_summarization_model()
# Dicts keyed by '<source>-<target>' pair strings such as 'en-fr'.
translation_models, translation_tokenizers = load_translation_models()
def summarize_text(article, *, max_length=256, min_length=50, num_beams=4):
    """Summarize ``article`` with the BART-large-CNN model.

    The article is passed to the model as-is. BART-large-CNN was fine-tuned on
    plain news articles, and task prefixes such as ``"summarize: "`` are a T5
    convention that BART would simply read as part of the input text.

    Args:
        article: Text to summarize. Tokens beyond the first 1024 are truncated.
        max_length: Upper bound on the summary length, in tokens.
        min_length: Lower bound on the summary length, in tokens. Very short
            inputs may be padded out with generated text to reach it.
        num_beams: Beam-search width.

    Returns:
        The decoded summary, or ``""`` when ``article`` is ``None``, empty or
        whitespace only (generating from no input would only hallucinate).
    """
    if not article or not article.strip():
        return ""
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which is forwarded to generate() via **inputs.
    inputs = summarization_tokenizer(
        article, return_tensors="pt", max_length=1024, truncation=True
    )
    summary_ids = summarization_model.generate(
        **inputs,
        max_length=max_length,
        min_length=min_length,
        length_penalty=2.0,
        num_beams=num_beams,
        early_stopping=True,
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    ``translation_models`` is keyed by ``'<source>-<target>'`` pair strings.
    When no direct model exists for the requested pair but both
    ``'<source>-en'`` and ``'en-<target>'`` do (e.g. fr -> de), the text is
    translated in two hops through English instead of failing.

    Args:
        text: Text to translate. Input longer than the tokenizer's maximum
            length is truncated.
        source_lang: Source language code, e.g. ``'fr'``.
        target_lang: Target language code, e.g. ``'de'``.

    Returns:
        The translated text. ``text`` is returned unchanged when the two
        languages match; ``""`` is returned for empty or whitespace-only input.

    Raises:
        KeyError: If there is neither a direct model nor an English pivot
            route for the requested pair.
    """
    if source_lang == target_lang:
        return text
    if not text or not text.strip():
        return ""
    language_pair = f'{source_lang}-{target_lang}'
    if language_pair not in translation_models:
        to_english = f'{source_lang}-en'
        from_english = f'en-{target_lang}'
        if to_english in translation_models and from_english in translation_models:
            # Pivot through English; each hop hits a direct model, so the
            # recursion is at most one level deep.
            english_text = translate_text(text, source_lang, 'en')
            return translate_text(english_text, 'en', target_lang)
        raise KeyError(f"No translation model for {source_lang} -> {target_lang}")
    model = translation_models[language_pair]
    tokenizer = translation_tokenizers[language_pair]
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated = model.generate(**inputs)
    return tokenizer.decode(translated[0], skip_special_tokens=True)
# Language codes offered in every source/target selector.
_LANGUAGES = ('en', 'fr', 'de', 'hi')

# Page styling injected once per run via st.markdown(unsafe_allow_html=True).
_PAGE_CSS = '''
<style>
h1 {
    color: #592f2f;
    font-size: 3em;
    font-weight: bold;
}
h2 {
    color: #f28e2c;
    font-size: 2em;
}
h3 {
    color: #e15759;
    font-size: 1.5em;
}
.stButton>button {
    background-color: #17499c;
    color: white;
    font-size: 1em;
    border-radius: 8px;
}
.stButton>button:hover {
    background-color: #0a2b61;
    color: white;
}
</style>
'''


def _render_summarize_task():
    """Render the summarize workflow plus optional translation of the summary."""
    st.header("Summarize Text")
    article = st.text_area("Enter the text here:")
    if st.button("Generate Summary"):
        if article.strip():
            with st.spinner("Generating summary..."):
                st.session_state.summary = summarize_text(article)
        else:
            st.warning("Please enter some text to summarize.")
    if st.session_state.summary:
        # Rendered outside the button branch on purpose: clicking
        # "Translate Summary" reruns the script with "Generate Summary" reading
        # False, and the summary would otherwise disappear from the page.
        st.subheader("Summary:")
        st.write(st.session_state.summary)
        st.header("Translate Summary")
        source_lang = st.selectbox('Source Language', _LANGUAGES, key='source_lang_summary')
        target_lang = st.selectbox('Translate Summary into', _LANGUAGES, key='target_lang_summary')
        if st.button("Translate Summary"):
            with st.spinner("Translating summary..."):
                translated_summary = translate_text(st.session_state.summary, source_lang, target_lang)
            st.subheader("Translated Summary:")
            st.write(translated_summary)
    else:
        st.warning("Please generate a summary first.")


def _render_translate_task():
    """Render the free-text translation workflow."""
    st.header("Translate Text")
    article = st.text_area("Enter the text here:")
    source_lang = st.selectbox('Source Language', _LANGUAGES, key='source_lang_article')
    target_lang = st.selectbox('Translate into', _LANGUAGES, key='target_lang_article')
    if st.button("Translate Article"):
        if not article.strip():
            st.warning("Please enter some text to translate.")
            return
        with st.spinner("Translating..."):
            translated_article = translate_text(article, source_lang, target_lang)
        st.subheader("Translated Article:")
        st.write(translated_article)


def main():
    """Build the page: inject CSS, show the title, then render the chosen task."""
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.title("Multilingual Text :blue[Summarizer] and :blue[Translator] :sunglasses:")
    task = st.selectbox('Choose a task', ('Summarize', 'Translate'))
    # The summary lives in session_state so it survives reruns triggered by
    # later widget interactions (e.g. the translate-summary controls).
    if 'summary' not in st.session_state:
        st.session_state.summary = None
    if task == 'Summarize':
        _render_summarize_task()
    elif task == 'Translate':
        _render_translate_task()
# Build the UI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()