# Japanese_NLP / app.py
# Author: shubh2014shiv
# Last commit: 8144a1f ("corrected path references")
# File size: 12.6 kB
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
from st_aggrid.shared import GridUpdateMode, JsCode
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertForSequenceClassification,
    T5Tokenizer,
)
# Page-wide layout and the app title.
st.set_page_config(layout="wide")
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")

# Sidebar topic picker: `topic` decides which demo section is rendered below.
st.sidebar.subheader("自然言語処理 トピック")
topic = st.sidebar.radio(
    label="Select the NLP project topics",
    options=["Sentiment Analysis", "Text Summarization"],
)
st.write("-----")  # short separator under the title

# Review text for the sentiment section; remains None until a review is
# selected from the grid or typed into the text area.
jp_review_text = None
if topic == "Sentiment Analysis":
    # ---------------------------------------------------------------------
    # Sentiment analysis: show Japanese Amazon reviews in a grid, let the user
    # pick one (or type one), tokenize it and classify it with a fine-tuned
    # BERT model.
    # ---------------------------------------------------------------------
    st.markdown(
        "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiments Analysis using BERT<b></h2>",
        unsafe_allow_html=True)
    st.markdown(
        "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)<b></h3>",
        unsafe_allow_html=True)

    # Shuffle with a fixed seed (same order on every rerun) and keep at most
    # the first 16,000 rows for the grid.
    amazon_jp_reviews = (
        pd.read_csv("review_val.csv")
        .sample(frac=1, random_state=10)
        .iloc[:16000]
    )

    # Browser-side AgGrid cell style for the "sentiment" column: green
    # background when the cell text contains 'positive', red otherwise.
    cellstyle_jscode = JsCode(
        """
function(params) {
if (params.value.includes('positive')) {
return {
'color': 'black',
'backgroundColor': '#32CD32'
}
} else {
return {
'color': 'black',
'backgroundColor': '#FF7F7F'
}
}
};
"""
    )

    # Lay the mode-selection radio buttons out horizontally and centred.
    st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
             unsafe_allow_html=True)
    # Extra tweaks targeting Streamlit-generated CSS classes
    # (NOTE(review): these class names presumably vary across Streamlit versions).
    st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
             unsafe_allow_html=True)

    # The two input modes: pick a review from the grid, or type one manually.
    SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
    WRITE_REVIEW = "Manually write review"
    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
    gb.configure_pagination()
    if choose == SELECT_ONE_REVIEW:
        # Checkbox-driven single-row selection is only needed in "choose" mode.
        gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    gridOptions = gb.build()

    if choose == SELECT_ONE_REVIEW:
        jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
                                  enable_enterprise_modules=True,
                                  allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one of the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
        # NOTE(review): assumes st_aggrid returns 'selected_rows' as a list of
        # row dicts (older st_aggrid API); newer releases may return a
        # DataFrame or None -- confirm against the pinned version.
        if len(jp_review_choice['selected_rows']) != 0:
            jp_review_text = jp_review_choice['selected_rows'][0]['review']
            st.markdown(
                "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)<b></h3>",
                unsafe_allow_html=True)
            st.write(jp_review_choice['selected_rows'])

    if choose == WRITE_REVIEW:
        # Grid is shown read-only for reference in this mode.
        AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
               enable_enterprise_modules=True,
               allow_unsafe_jscode=True)
        with open("test_reviews_jp.csv", "rb") as reviews_file:
            st.download_button(label="Download Additional Japanese Reviews", data=reviews_file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
                                      value=sample_japanese_review_input)
        if not jp_review_text:
            st.error("Input text cannot be empty. Either write the japanese review in text area manually or select the review from the grid.")

    if jp_review_text:
        # --- Tokenization view -------------------------------------------
        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Sentence-Piece based Japanese Tokenizer using RoBERTA<b></h3>",
            unsafe_allow_html=True)
        tokens_column, tokenID_column = st.columns(2)
        # The rinna RoBERTa checkpoint's SentencePiece vocabulary is loaded via T5Tokenizer.
        # NOTE(review): re-loaded on every Streamlit rerun; consider caching
        # (e.g. st.cache_resource) if the pinned Streamlit version provides it.
        tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
        tokens = tokenizer.tokenize(jp_review_text)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        with tokens_column:
            with st.expander("Expand to see the tokens", expanded=False):
                st.write(tokens)
        with tokenID_column:
            with st.expander("Expand to see the token IDs", expanded=False):
                st.write(token_ids)

        # --- Model inputs ------------------------------------------------
        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor<b></h3>",
            unsafe_allow_html=True)
        # Encode the single review as a batch of one: add special tokens,
        # truncate to at most 200 tokens, return PyTorch tensors.
        encoded_data = tokenizer.batch_encode_plus([jp_review_text],
                                                   add_special_tokens=True,
                                                   return_attention_mask=True,
                                                   padding=True,
                                                   max_length=200,
                                                   return_tensors='pt',
                                                   truncation=True)
        input_ids = encoded_data['input_ids']
        attention_masks = encoded_data['attention_mask']
        input_ids_column, attention_masks_column = st.columns(2)
        with input_ids_column:
            with st.expander("Expand to see the input IDs tensor"):
                st.write(input_ids)
        with attention_masks_column:
            with st.expander("Expand to see the attention mask tensor"):
                st.write(attention_masks)

        # --- Prediction --------------------------------------------------
        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predict Sentiment of review using Fine-Tuned Japanese BERT<b></h3>",
            unsafe_allow_html=True)
        # Class index of each label in the classifier's output logits.
        label_dict = {'positive': 1, 'negative': 0}
        if st.button("Predict Sentiment"):
            with st.spinner("Wait.."):
                model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
                                                                      num_labels=len(label_dict),
                                                                      output_attentions=False,
                                                                      output_hidden_states=False)
                # Overwrite the downloaded weights with the local fine-tuned checkpoint.
                model.load_state_dict(
                    torch.load('reviewSentiments_jp.pt',
                               map_location=torch.device('cpu')))
                model.eval()  # inference only: make sure dropout is disabled
                with torch.no_grad():
                    logits = model(input_ids=input_ids, attention_mask=attention_masks).logits
                # The two classes are mutually exclusive, so softmax (not a
                # per-logit sigmoid) is what yields a probability distribution
                # that sums to 1, as the chart below claims to show.
                probabilities = torch.softmax(logits, dim=-1)[0].cpu().numpy()
            result = {"TEXT (文章)": jp_review_text,
                      'NEGATIVE (ネガティブ)': float(probabilities[label_dict['negative']]),
                      'POSITIVE (ポジティブ)': float(probabilities[label_dict['positive']])}
            result_col, graph_col = st.columns(2)
            with result_col:
                st.write(result)
            with graph_col:
                fig = px.bar(x=['NEGATIVE (ネガティブ)', 'POSITIVE (ポジティブ)'],
                             y=[result['NEGATIVE (ネガティブ)'], result['POSITIVE (ポジティブ)']])
                fig.update_layout(title="Probability distribution of Sentiment for the given text",
                                  yaxis_title="Probability (確率)")
                fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
                st.plotly_chart(fig)

elif topic == "Text Summarization":
    # ---------------------------------------------------------------------
    # Summarization: pick a Japanese news article from the grid and
    # summarize it with the multilingual mT5 XLSum model.
    # ---------------------------------------------------------------------
    st.markdown(
        "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Summarizing Japanese News Article using multi-Lingual T5 (mT5)<b></h2>",
        unsafe_allow_html=True)
    st.markdown(
        "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese News Article Data<b></h3>",
        unsafe_allow_html=True)
    # Deterministic 75% random subset of the articles.
    news_articles = pd.read_csv("jp_news_articles.csv").sample(frac=0.75,
                                                               random_state=42)
    gb = GridOptionsBuilder.from_dataframe(news_articles)
    gb.configure_pagination()
    gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    gridOptions = gb.build()
    jp_article = AgGrid(news_articles, gridOptions=gridOptions, theme='material',
                        enable_enterprise_modules=True,
                        allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
    # NOTE(review): same 'selected_rows' list-of-dicts assumption as in the
    # sentiment section -- confirm against the pinned st_aggrid version.
    if len(jp_article['selected_rows']) == 0:
        st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
    else:
        article_text = jp_article['selected_rows'][0]['News Articles']
        # Editable copy of the article; this (possibly edited) text is what gets summarized.
        text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
        summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )", min_value=120, max_value=160, step=5)
        if text and st.button("Summarize it! (要約しよう)"):
            waitPlaceholder = st.image("wait.gif")
            try:
                summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
                tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
                model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
                # Pad/truncate to 512 tokens; the attention mask is passed
                # explicitly so padding positions are ignored by the encoder.
                encoded_article = tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=512
                )
                output_ids = model.generate(
                    input_ids=encoded_article["input_ids"],
                    attention_mask=encoded_article["attention_mask"],
                    max_length=summary_length,
                    no_repeat_ngram_size=2,
                    num_beams=4
                )[0]
                summary = tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )
            finally:
                # Always remove the wait animation, even if loading/generation fails.
                waitPlaceholder.empty()
            st.markdown(
                "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Summary (要約文)<b></h2>",
                unsafe_allow_html=True)
            st.write(summary)