Abhishek-D7 committed
Commit 00a9d47
Parent(s): 193ff0e

Update app.py

Files changed (1): app.py (+3, -3)
app.py CHANGED
@@ -1,11 +1,11 @@
 import streamlit as st
 #from transformers import T5Tokenizer, T5ForConditionalGeneration, MarianMTModel, MarianTokenizer
 import sentencepiece
-from transformers import BartTokenizer, BartForConditionalGeneration, MarianMTModel, MarianTokenizer
+from transformers import AutoTokenizer, BartForConditionalGeneration, MarianMTModel, MarianTokenizer
 
 def load_summarization_model():
     model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
-    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
+    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
     return model, tokenizer
 
 summarization_model, summarization_tokenizer = load_summarization_model()
@@ -30,7 +30,7 @@ translation_models, translation_tokenizers = load_translation_models()
 
 def summarize_text(article):
     inputs = summarization_tokenizer.encode("summarize: " + article, return_tensors="pt", max_length=6000, truncation=True)
-    summary_ids = summarization_model.generate(inputs, max_length=64, min_length=10, length_penalty=2.0, num_beams=4, early_stopping=True)
+    summary_ids = summarization_model.generate(inputs, max_length=256, min_length=10, length_penalty=2.0, num_beams=4, early_stopping=True)
     return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 def translate_text(text, source_lang, target_lang):
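
For context, here is a minimal runnable sketch (not part of the commit) of the updated summarization path, assuming the same facebook/bart-large-cnn checkpoint. AutoTokenizer.from_pretrained reads the checkpoint's config and returns the matching tokenizer class, so here it resolves to the (fast) BART tokenizer and is equivalent to the old BartTokenizer import, while continuing to work if the checkpoint is swapped later. Two caveats the sketch reflects: BART's position embeddings cap encoder input at 1024 tokens, so it truncates there rather than at app.py's max_length=6000, and the "summarize: " prefix is a T5 convention that BART does not need.

# Minimal sketch of the updated summarization path; the article text is a placeholder.
from transformers import AutoTokenizer, BartForConditionalGeneration

model_name = 'facebook/bart-large-cnn'

# AutoTokenizer resolves to the BART (fast) tokenizer for this checkpoint,
# matching what BartTokenizer.from_pretrained(model_name) returned before.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

article = "Text of a long news article ..."  # placeholder input

# Truncate at BART's 1024-token position limit; inputs longer than that
# would raise an indexing error inside the model.
inputs = tokenizer.encode(article, return_tensors="pt",
                          max_length=1024, truncation=True)

# max_length=256 (raised from 64 in this commit) bounds the summary length in
# tokens; num_beams=4 enables beam search, and length_penalty=2.0 favors the
# longer candidates within that bound.
summary_ids = model.generate(inputs, max_length=256, min_length=10,
                             length_penalty=2.0, num_beams=4,
                             early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

The substantive generation change is the max_length bump from 64 to 256: 64 tokens often cut BART-CNN's multi-sentence summaries off mid-sentence, while 256 leaves headroom for its typical output.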