|
import streamlit as st |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
from st_aggrid import AgGrid |
|
from st_aggrid.grid_options_builder import GridOptionsBuilder |
|
from st_aggrid.shared import JsCode |
|
from st_aggrid.shared import GridUpdateMode |
|
from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import numpy as np |
|
import json |
|
from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead |
|
import pytorch_lightning as pl |
|
from pathlib import Path |
|
|
|
|
|
class TranslationModel(pl.LightningModule):
    """Lightning wrapper around the Helsinki-NLP Japanese-to-English model.

    The wrapped Hugging Face model is exposed as ``self.model``. Wrapping it
    in a ``LightningModule`` lets a fine-tuned checkpoint be restored through
    ``TranslationModel.load_from_checkpoint``.
    """

    def __init__(self):
        super().__init__()
        # ``AutoModelWithLMHead`` is deprecated in transformers. For this
        # encoder-decoder (Marian) checkpoint the supported replacement is
        # ``AutoModelForSeq2SeqLM``, which this file already imports.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-ja-en",
            return_dict=True,
        )
|
|
|
|
|
|
|
def loadFineTunedJaEn_NMT_Model():
    """Load the fine-tuned Japanese-to-English translation model from disk.

    The checkpoint file
    ``business_dialogue_japanese_english_model_fine_tuned.ckpt`` is resolved
    relative to the current working directory.

    Returns:
        The ``TranslationModel`` instance restored by
        ``TranslationModel.load_from_checkpoint``.
    """
    # NOTE: a disabled code path that downloaded the checkpoint from Google
    # Drive into ``model/best-checkpoint.ckpt`` used to live here inside a
    # string literal (which silently became this function's docstring). It
    # was never executed and has been removed.
    checkpoint_path = Path("business_dialogue_japanese_english_model_fine_tuned.ckpt")
    return TranslationModel.load_from_checkpoint(checkpoint_path)
|
|
|
@st.experimental_singleton
def getJpEn_Tokenizers():
    """Load the tokenizers used by the Japanese-to-English translator.

    Returns:
        tuple: ``(ja_tokenizer, en_tokenizer)`` where ``ja_tokenizer`` is the
        ``Helsinki-NLP/opus-mt-ja-en`` tokenizer and ``en_tokenizer`` is the
        ``bert-base-uncased`` tokenizer.

    Raises:
        Exception: whatever the tokenizer loading raised. An error box is
        shown in the app first, then the original exception is re-raised.
    """
    try:
        with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
            ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
            en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    except Exception:
        # Previously a bare ``except:`` fell through to the ``return`` below
        # with both names unbound, masking the real failure behind an
        # UnboundLocalError. Surface the message and keep the real cause.
        st.error("Issue with downloading tokenizers")
        raise

    return ja_tokenizer, en_tokenizer
|
|
|
# --- Page chrome: wide layout, title, and the sidebar topic selector. ---
st.set_page_config(layout="wide")
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
st.sidebar.subheader("自然言語処理 トピック")

# The selected label decides which of the three demo pages is rendered below.
topic = st.sidebar.radio(
    label="Select the NLP project topics",
    options=[
        "Sentiment Analysis",
        "Text Summarization",
        "Japanese to English Translation",
    ],
)

# Five dashes, written out literally.
st.write("-----")

# Review text for the sentiment page; stays None until one is chosen or typed.
jp_review_text = None
|
|
|
|
|
# Labels of the two review-input modes offered on the sentiment page.
SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
WRITE_REVIEW = "Manually write review"


def _section_heading(text, tag="h3", color="#F63366", size=18, align="center"):
    """Render *text* as a bold, coloured HTML heading through ``st.markdown``."""
    st.markdown(
        f"<{tag} style='text-align: {align}; color:{color}; font-size:{size}px;'>"
        f"<b>{text}</b></{tag}>",
        unsafe_allow_html=True,
    )


def _page_title(text):
    """Render the violet, left-aligned title shown at the top of a topic page."""
    _section_heading(text, tag="h2", color="#EE82EE", size=25, align="left")


def _result_title(text):
    """Render the green, left-aligned title shown above a selection or result."""
    _section_heading(text, tag="h2", color="#32CD32", size=25, align="left")


def _show_grid(frame, grid_options, **extra):
    """Display *frame* in an AgGrid with the app's common settings and return its response."""
    return AgGrid(
        frame,
        gridOptions=grid_options,
        theme='material',
        enable_enterprise_modules=True,
        allow_unsafe_jscode=True,
        **extra,
    )


def _single_select_grid(frame):
    """Display *frame* as a paginated grid with single-row checkbox selection."""
    builder = GridOptionsBuilder.from_dataframe(frame)
    builder.configure_pagination()
    builder.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    return _show_grid(frame, builder.build(), update_mode=GridUpdateMode.SELECTION_CHANGED)


def _side_by_side_expanders(left, right):
    """Show two ``(label, value)`` pairs as collapsed expanders in two columns."""
    for column, (label, value) in zip(st.columns(2), (left, right)):
        with column, st.expander(label):
            st.write(value)


def _flatten_bsd_dialogues(dialogues):
    """Return ``(japanese, english)`` sentence lists from the loaded BSD JSON.

    Each item of *dialogues* is read through its ``'conversation'`` list, and
    each turn through its ``'ja_sentence'`` / ``'en_sentence'`` keys.
    """
    japanese, english = [], []
    for dialogue in dialogues:
        for turn in dialogue['conversation']:
            english.append(turn['en_sentence'])
            japanese.append(turn['ja_sentence'])
    return japanese, english


def _translate_ja_to_en(text, ja_tokenizer, en_tokenizer, nmt_model):
    """Translate one Japanese sentence with the fine-tuned model and return the English string."""
    text_encoding = ja_tokenizer(
        text,
        max_length=100,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    generated_ids = nmt_model.model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=100,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    preds = [
        en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]
    # NOTE(review): the first five characters of the decoded text are dropped,
    # exactly as before. Presumably this strips a constant artefact prefix
    # produced by the model/tokenizer pairing - confirm before changing.
    return "".join(preds)[5:]


def _render_sentiment_analysis():
    """Render the sentiment page: pick or type a review, inspect its encoding, classify it."""
    _page_title("Transfer Learning based Japanese Sentiments Analysis using BERT")
    _section_heading("Japanese Amazon Reviews Data (日本のAmazonレビューデータ)")

    amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1, random_state=10).iloc[:16000]

    # Colours the "sentiment" cell green for positive rows and red otherwise.
    cellstyle_jscode = JsCode(
        """
        function(params) {
            if (params.value.includes('positive')) {
                return {
                    'color': 'black',
                    'backgroundColor': '#32CD32'
                }
            } else {
                return {
                    'color': 'black',
                    'backgroundColor': '#FF7F7F'
                }
            }
        };
        """
    )
    st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
             unsafe_allow_html=True)
    st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
             unsafe_allow_html=True)

    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    builder = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    builder.configure_column("sentiment", cellStyle=cellstyle_jscode)
    builder.configure_pagination()

    review_text = None
    if choose == SELECT_ONE_REVIEW:
        builder.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
        selection = _show_grid(amazon_jp_reviews, builder.build(),
                               update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
        selected_rows = selection['selected_rows']
        if len(selected_rows) != 0:
            review_text = selected_rows[0]['review']
            _section_heading("Selected Review in JSON (JSONで選択されたレビュー)")
            st.write(selected_rows)
    elif choose == WRITE_REVIEW:
        _show_grid(amazon_jp_reviews, builder.build())
        with open("test_reviews_jp.csv", "rb") as file:
            st.download_button(label="Download Additional Japanese Reviews", data=file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
                                   value=sample_japanese_review_input)
        if len(review_text) == 0:
            st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")

    # Nothing selected or typed yet: there is nothing to tokenize or classify.
    if not review_text:
        return

    _section_heading("Sentence-Piece based Japanese Tokenizer using RoBERTA")
    tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
    tokens = tokenizer.tokenize(review_text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    _side_by_side_expanders(
        ("Expand to see the tokens", tokens),
        ("Expand to see the token IDs", token_ids),
    )

    _section_heading("Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor")
    encoded_data = tokenizer.batch_encode_plus(
        [review_text],
        add_special_tokens=True,
        return_attention_mask=True,
        padding=True,
        max_length=200,
        return_tensors='pt',
        truncation=True,
    )
    input_ids = encoded_data['input_ids']
    attention_masks = encoded_data['attention_mask']
    _side_by_side_expanders(
        ("Expand to see the input IDs tensor", input_ids),
        ("Expand to see the attention mask tensor", attention_masks),
    )

    _section_heading("Predict Sentiment of review using Fine-Tuned Japanese BERT")
    label_dict = {'positive': 1, 'negative': 0}
    if not st.button("Predict Sentiment"):
        return

    with st.spinner("Wait.."):
        model = BertForSequenceClassification.from_pretrained(
            "shubh2014shiv/jp_review_sentiments_amzn",
            num_labels=len(label_dict),
            output_attentions=False,
            output_hidden_states=False,
        )
        model.load_state_dict(
            torch.load('reviewSentiments_jp.pt', map_location=torch.device('cpu'))
        )
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_masks)

        # The classifier emits one logit per class. Softmax over the class
        # axis yields scores that sum to 1, matching the "probability
        # distribution" the chart title promises; an element-wise sigmoid
        # does not.
        probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]

    negative_key = 'NEGATIVE (ネガティブ)'
    positive_key = 'POSITIVE (ポジティブ)'
    result = {
        "TEXT (文章)": review_text,
        negative_key: float(probabilities[label_dict['negative']]),
        positive_key: float(probabilities[label_dict['positive']]),
    }

    result_col, graph_col = st.columns(2)
    with result_col:
        st.write(result)
    with graph_col:
        fig = px.bar(x=[negative_key, positive_key],
                     y=[result[negative_key], result[positive_key]])
        fig.update_layout(title="Probability distribution of Sentiment for the given text",
                          yaxis_title="Probability (確率)")
        fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
        st.plotly_chart(fig)


def _render_text_summarization():
    """Render the summarization page: pick a news article, optionally edit it, summarize it."""
    _page_title("Summarizing Japanese News Article using multi-Lingual T5 (mT5)")
    _section_heading("Japanese News Article Data")

    news_articles = pd.read_csv("jp_news_articles_val.csv").sample(frac=0.75, random_state=42)
    selected_rows = _single_select_grid(news_articles)['selected_rows']

    if len(selected_rows) == 0:
        st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
        return

    article_text = selected_rows[0]['News Articles']
    text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
    summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )",
                               min_value=120, max_value=160, step=5)

    # The button is only drawn when the text area is non-empty (short-circuit).
    if not (text and st.button("Summarize it! (要約しよう)")):
        return

    wait_placeholder = st.image("wait.gif")
    try:
        summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
        tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)

        # Summarize what is in the text area, so edits made by the user are
        # honoured (previously the untouched grid selection was used).
        input_ids = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )["input_ids"]

        output_ids = model.generate(
            input_ids=input_ids,
            max_length=summary_length,
            no_repeat_ngram_size=2,
            num_beams=4,
        )[0]

        summary = tokenizer.decode(
            output_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    finally:
        # Remove the "please wait" image even if loading or generation fails.
        wait_placeholder.empty()

    _result_title("Summary (要約文)")
    st.write(summary)


def _render_translation():
    """Render the translation page: pick a BSD corpus sentence and translate it to English."""
    _page_title("Japanese to English translation (for short sentences)")
    _section_heading("Business Scene Dialog Japanese-English Corpus")

    st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
    link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
    st.markdown(link, unsafe_allow_html=True)

    with st.expander(label="Expand to get more information on data and training report"):
        _section_heading("Training Dataset", size=12, align="left")
        st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
        _section_heading("Training Report", size=12, align="left")
        st.write(
            "The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")

    with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
        bsd_sample_data = json.load(f)

    ja, en = _flatten_bsd_dialogues(bsd_sample_data)
    df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})

    selected_rows = _single_select_grid(df)['selected_rows']
    if len(selected_rows) == 0:
        return

    bsd_jp = selected_rows[0]['Japanese']
    _result_title("Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)")
    st.write(bsd_jp)

    if not st.button("Translate"):
        return

    wait_placeholder = st.image("wait.gif")
    try:
        ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
        trained_model = loadFineTunedJaEn_NMT_Model()
        trained_model.freeze()
        translation = _translate_ja_to_en(bsd_jp, ja_tokenizer, en_tokenizer, trained_model)
    finally:
        # Clear the "please wait" image only once loading and translation
        # have finished (or failed), not before the translation runs.
        wait_placeholder.empty()

    _result_title("Translated Dialog in English (英語の翻訳されたダイアログ)")
    st.write(translation)


# Dispatch on the sidebar selection; exactly one page is rendered per run.
if topic == "Sentiment Analysis":
    _render_sentiment_analysis()
elif topic == "Text Summarization":
    _render_text_summarization()
elif topic == "Japanese to English Translation":
    _render_translation()
|
|