tareesh committed on
Commit
3a74906
1 Parent(s): ce6c5dc

Update app.py

Files changed (1):
  1. app.py +12 -10
app.py CHANGED
@@ -2,19 +2,21 @@ import streamlit as st
 from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
-model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
-model = AutoModelForMaskedLM.from_pretrained("alabnii/jmedroberta-base-sentencepiece")
-model.eval()
-tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_large="en_XX")
-tokenizer = AutoTokenizer.from_pretrained("alabnii/jmedroberta-base-sentencepiece")
+# Load the models and tokenizers
+model_translation = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
+model_masked_lm = AutoModelForMaskedLM.from_pretrained("alabnii/jmedroberta-base-sentencepiece")
+model_translation.eval()
+
+tokenizer_translation = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
+tokenizer_masked_lm = AutoTokenizer.from_pretrained("alabnii/jmedroberta-base-sentencepiece")
 
 text = st.text_area('Enter the text:')
 
 if text:
-    model_inputs = tokenizer(text, return_tensors="pt")
-    generated_tokens = model.generate(
+    model_inputs = tokenizer_translation(text, return_tensors="pt")
+    generated_tokens = model_translation.generate(
         **model_inputs,
-        forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"]
+        forced_bos_token_id=tokenizer_translation.lang_code_to_id["hi_IN"]
     )
-    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-    st.json(translation)
+    translation = tokenizer_translation.batch_decode(generated_tokens, skip_special_tokens=True)
+    st.json(translation)
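
For reference, the mBART-50 one-to-many translation flow this commit wires into the Streamlit app can be exercised on its own. A minimal sketch, assuming transformers and torch are installed; the model identifiers come from the diff above, while the sample input sentence is purely illustrative:

    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

    # Same checkpoint as in app.py; src_lang sets the source language code
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX"  # source: English
    )
    model.eval()

    # Illustrative input sentence (not from the commit)
    inputs = tokenizer("The patient shows no signs of fever.", return_tensors="pt")

    # The one-to-many checkpoint selects its output language by forcing the
    # target language code as the first generated token
    generated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"],  # target: Hindi
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))

forced_bos_token_id is what picks the target language here: tokenizer.lang_code_to_id maps codes such as "hi_IN" to the corresponding token id, and without it the one-to-many model has no way to know which target language to decode into.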