shubh2014shiv commited on
Commit
e9cddb1
1 Parent(s): 98af3bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -1
app.py CHANGED
@@ -9,11 +9,49 @@ from st_aggrid.shared import GridUpdateMode
9
  from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
10
  import torch
11
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  st.set_page_config(layout="wide")
14
  st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
15
  st.sidebar.subheader("自然言語処理 トピック")
16
- topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization"])
17
 
18
  st.write("-" * 5)
19
  jp_review_text = None
@@ -235,3 +273,89 @@ elif topic == "Text Summarization":
235
  unsafe_allow_html=True)
236
 
237
  st.write(summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
10
  import torch
11
  import numpy as np
12
+ import json
13
+ from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead
14
+ import pytorch_lightning as pl
15
+ from pathlib import Path
16
+
17
+ # Helper definitions for the translation demo; the loader functions below are cached by Streamlit
18
class TranslationModel(pl.LightningModule):
    """Lightning wrapper around the pretrained ``Helsinki-NLP/opus-mt-ja-en`` model.

    The wrapped network is exposed as ``self.model``.
    """

    def __init__(self):
        super().__init__()
        # AutoModelWithLMHead is deprecated in transformers; AutoModelForSeq2SeqLM
        # (already imported at the top of this file) is the documented replacement
        # for encoder-decoder checkpoints such as this one.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-ja-en", return_dict=True
        )
22
+
23
+
24
@st.experimental_singleton
def loadFineTunedJaEn_NMT_Model():
    """Return the fine-tuned translation model, downloading its checkpoint if needed.

    The checkpoint is expected at ``model/best-checkpoint.ckpt``. When it is
    absent it is fetched from Google Drive via
    ``GD_download.download_file_from_google_drive``.

    Returns:
        The ``TranslationModel`` restored by
        ``TranslationModel.load_from_checkpoint``.
    """
    save_dest = Path('model')
    save_dest.mkdir(exist_ok=True)

    f_checkpoint = save_dest / "best-checkpoint.ckpt"

    if not f_checkpoint.exists():
        with st.spinner("Downloading model.This may take a while! \n Don't refresh or close this page!"):
            from GD_download import download_file_from_google_drive

            # Download under a temporary name and rename only once the helper
            # returns, so an interrupted download is never mistaken for a
            # complete checkpoint by the exists() check on a later run.
            partial_checkpoint = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', partial_checkpoint)
            partial_checkpoint.replace(f_checkpoint)

    trained_model = TranslationModel.load_from_checkpoint(f_checkpoint)

    return trained_model
39
+
40
@st.experimental_singleton
def getJpEn_Tokenizers():
    """Load the two tokenizers used by the Japanese-to-English translation demo.

    Returns:
        tuple: ``(ja_tokenizer, en_tokenizer)`` - the tokenizer loaded from
        ``Helsinki-NLP/opus-mt-ja-en`` and the one loaded from
        ``bert-base-uncased``.

    Raises:
        Exception: Whatever ``from_pretrained`` raised, re-raised after an
            error message has been shown in the UI.
    """
    try:
        with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
            ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
            en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    except Exception:
        st.error("Issue with downloading tokenizers")
        # Re-raise the original error. Falling through to the return below
        # would hide it behind an UnboundLocalError, because the tokenizer
        # names were never assigned.
        raise

    return ja_tokenizer, en_tokenizer
50
 
51
  st.set_page_config(layout="wide")
52
  st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
53
  st.sidebar.subheader("自然言語処理 トピック")
54
+ topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization","Japanese to English Translation"])
55
 
56
  st.write("-" * 5)
57
  jp_review_text = None
 
273
  unsafe_allow_html=True)
274
 
275
  st.write(summary)
276
+ elif topic == "Japanese to English Translation":
277
+ st.markdown(
278
+ "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Japanese to English translation (for short sentences)<b></h2>",
279
+ unsafe_allow_html=True)
280
+ st.markdown(
281
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Business Scene Dialog Japanese-English Corpus<b></h3>",
282
+ unsafe_allow_html=True)
283
+
284
+ st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
285
+ link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
286
+ st.markdown(link, unsafe_allow_html=True)
287
+
288
+ bsd_more_info = st.expander(label="Expand to get more information on data and training report")
289
+ with bsd_more_info:
290
+ st.markdown(
291
+ "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Dataset<b></h3>",
292
+ unsafe_allow_html=True)
293
+ st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
294
+
295
+ st.markdown(
296
+ "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Report<b></h3>",
297
+ unsafe_allow_html=True)
298
+ st.write(
299
+ "The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")
300
+
301
+ with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
302
+ bsd_sample_data = json.load(f)
303
+
304
+ en, ja = [], []
305
+ for i in range(len(bsd_sample_data)):
306
+ for j in range(len(bsd_sample_data[i]['conversation'])):
307
+ en.append(bsd_sample_data[i]['conversation'][j]['en_sentence'])
308
+ ja.append(bsd_sample_data[i]['conversation'][j]['ja_sentence'])
309
+
310
+ df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})
311
+ gb = GridOptionsBuilder.from_dataframe(df)
312
+ gb.configure_pagination()
313
+ gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
314
+ gridOptions = gb.build()
315
+ translation_text = AgGrid(df, gridOptions=gridOptions, theme='material',
316
+ enable_enterprise_modules=True,
317
+ allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
318
+ if len(translation_text['selected_rows']) != 0:
319
+ bsd_jp = translation_text['selected_rows'][0]['Japanese']
320
+ st.markdown(
321
+ "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)<b></h2>",
322
+ unsafe_allow_html=True)
323
+ st.write(bsd_jp)
324
+
325
+ if st.button("Translate"):
326
+ ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
327
+ trained_model = loadFineTunedJaEn_NMT_Model()
328
+ trained_model.freeze()
329
+
330
+
331
def translate(text):
    """Translate ``text`` with the fine-tuned model and return the decoded string.

    Uses ``ja_tokenizer``, ``en_tokenizer`` and ``trained_model`` from the
    enclosing scope.
    """
    # Encode the input, padded/truncated to a fixed length of 100 tokens.
    encoded = ja_tokenizer(
        text,
        max_length=100,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )

    # Beam-search generation settings, kept together for readability.
    generation_kwargs = dict(
        max_length=100,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    output_ids = trained_model.model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        **generation_kwargs,
    )

    decoded_pieces = []
    for sequence_ids in output_ids:
        decoded_pieces.append(
            en_tokenizer.decode(
                sequence_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )

    joined = "".join(decoded_pieces)
    # NOTE(review): the first five characters are dropped; presumably a fixed
    # prefix emitted by the decoder - confirm against the training setup.
    return joined[5:]
356
+
357
+
358
+ st.markdown(
359
+ "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Translated Dialog in English (英語の翻訳されたダイアログ)<b></h2>",
360
+ unsafe_allow_html=True)
361
+ st.write(translate(bsd_jp))