Abhishek-D7
committed on
Commit
•
00a9d47
1
Parent(s):
193ff0e
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
#from transformers import T5Tokenizer, T5ForConditionalGeneration, MarianMTModel, MarianTokenizer
|
3 |
import sentencepiece
|
4 |
-
from transformers import
|
5 |
|
6 |
def load_summarization_model():
|
7 |
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
|
8 |
-
tokenizer =
|
9 |
return model, tokenizer
|
10 |
|
11 |
summarization_model, summarization_tokenizer = load_summarization_model()
|
@@ -30,7 +30,7 @@ translation_models, translation_tokenizers = load_translation_models()
|
|
30 |
|
31 |
def summarize_text(article):
|
32 |
inputs = summarization_tokenizer.encode("summarize: " + article, return_tensors="pt", max_length=6000, truncation=True)
|
33 |
-
summary_ids = summarization_model.generate(inputs, max_length=
|
34 |
return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
35 |
|
36 |
def translate_text(text, source_lang, target_lang):
|
|
|
1 |
import streamlit as st
|
2 |
#from transformers import T5Tokenizer, T5ForConditionalGeneration, MarianMTModel, MarianTokenizer
|
3 |
import sentencepiece
|
4 |
+
from transformers import AutoTokenizer, BartForConditionalGeneration, MarianMTModel, MarianTokenizer
|
5 |
|
6 |
def load_summarization_model():
    """Load the pretrained BART CNN checkpoint and its matching tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a
        ``BartForConditionalGeneration`` and ``tokenizer`` the
        ``AutoTokenizer`` for 'facebook/bart-large-cnn'.
    """
    checkpoint = 'facebook/bart-large-cnn'
    summarizer = BartForConditionalGeneration.from_pretrained(checkpoint)
    bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return summarizer, bart_tokenizer
|
10 |
|
11 |
summarization_model, summarization_tokenizer = load_summarization_model()
|
|
|
30 |
|
31 |
def summarize_text(article):
    """Generate an abstractive summary of *article* with the BART CNN model.

    Args:
        article (str): Source text to summarize.

    Returns:
        str: Decoded summary with special tokens removed.
    """
    # BUGFIX: facebook/bart-large-cnn supports at most 1024 positional
    # embeddings. The previous max_length=6000 allowed encode() to emit
    # sequences longer than the model can index, raising a runtime error on
    # long articles. Truncation at 1024 keeps every input within capacity.
    # NOTE(review): the "summarize: " prefix is a T5-style prompt and is not
    # required by BART; kept to preserve existing behavior — confirm intent.
    inputs = summarization_tokenizer.encode(
        "summarize: " + article,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    # Beam search (4 beams) with a length penalty favors fluent, moderately
    # long summaries; early_stopping halts beams once they emit EOS.
    summary_ids = summarization_model.generate(
        inputs,
        max_length=256,
        min_length=10,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
35 |
|
36 |
def translate_text(text, source_lang, target_lang):
|