# Hugging Face Spaces app (page status at capture time: Runtime error)
"""Streamlit app: Japanese natural language processing with Transformers.

Currently offers one topic, sentiment analysis of Japanese Amazon reviews.
The user picks a review from an AgGrid table (or types one in). The app then
shows its SentencePiece tokens, token IDs and encoded tensors. On request it
runs a fine-tuned Japanese BERT classifier to get negative/positive
probabilities.
"""
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
from st_aggrid.shared import JsCode
from st_aggrid.shared import GridUpdateMode
from transformers import T5Tokenizer, BertForSequenceClassification
import torch
import numpy as np

# set_page_config must run before any other Streamlit command renders output.
st.set_page_config(layout="wide")

TOKENIZER_NAME = "rinna/japanese-roberta-base"
SENTIMENT_MODEL_NAME = "shubh2014shiv/jp_review_sentiments_amzn"
# Index of each sentiment class in the classifier's output logits.
label_dict = {'positive': 1, 'negative': 0}
SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
WRITE_REVIEW = "Manually write review"

# st.cache_resource (Streamlit >= 1.18) keeps one shared object across reruns
# and sessions; older Streamlit releases only provide the legacy st.cache.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


def _sub_header(text):
    """Render `text` as a centred pink <h3> section title (the page's recurring style)."""
    st.markdown(
        "<h3 style='text-align: center; color:#F63366; font-size:18px;'>"
        f"<b>{text}</b></h3>",
        unsafe_allow_html=True)


@_cache_resource
def _load_tokenizer():
    """Load the SentencePiece tokenizer once instead of on every script rerun."""
    return T5Tokenizer.from_pretrained(TOKENIZER_NAME)


@_cache_resource
def _load_sentiment_model():
    """Load the fine-tuned Japanese BERT sentiment classifier once instead of per click."""
    return BertForSequenceClassification.from_pretrained(
        SENTIMENT_MODEL_NAME,
        num_labels=len(label_dict),
        output_attentions=False,
        output_hidden_states=False)


def _first_selected_review(selected_rows):
    """Return the 'review' text of the first selected grid row, or None if nothing is selected.

    Older st_aggrid versions report the selection as a list of row dicts; some newer
    releases return a pandas DataFrame, or None when nothing is selected.
    """
    if selected_rows is None:
        return None
    if isinstance(selected_rows, pd.DataFrame):
        return None if selected_rows.empty else selected_rows.iloc[0]["review"]
    return selected_rows[0]["review"] if len(selected_rows) else None


st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
st.sidebar.subheader("自然言語処理 トピック")
topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis"])
st.write("-" * 5)

# Review text to analyse; stays None until a row is selected or valid text is typed.
jp_review_text = None

if topic == "Sentiment Analysis":
    st.markdown(
        "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'>"
        "<b>Transfer Learning based Japanese Sentiments Analysis using BERT</b></h2>",
        unsafe_allow_html=True)
    _sub_header("Japanese Amazon Reviews Data (日本のAmazonレビューデータ)")
    # Shuffle with a fixed seed so every rerun shows the same 16,000-review subset.
    amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1, random_state=10).iloc[:16000]

    # AgGrid cell style: green background when the cell value contains 'positive', red otherwise.
    cellstyle_jscode = JsCode(
        """
        function(params) {
            if (params.value.includes('positive')) {
                return {
                    'color': 'black',
                    'backgroundColor': '#32CD32'
                }
            } else {
                return {
                    'color': 'black',
                    'backgroundColor': '#FF7F7F'
                }
            }
        };
        """
    )
    # Lay the radio options out horizontally and bold their labels.
    st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
             unsafe_allow_html=True)
    st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
             unsafe_allow_html=True)
    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
    gb.configure_pagination()
    if choose == SELECT_ONE_REVIEW:
        gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    grid_options = gb.build()

    if choose == SELECT_ONE_REVIEW:
        jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=grid_options, theme='material',
                                  enable_enterprise_modules=True,
                                  allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one of the Japanese Reviews by clicking the checkbox. "
                "Reviews can be navigated from each page.")
        selected_rows = jp_review_choice['selected_rows']
        jp_review_text = _first_selected_review(selected_rows)
        if jp_review_text is not None:
            _sub_header("Selected Review in JSON (JSONで選択されたレビュー)")
            st.write(selected_rows)
    else:  # WRITE_REVIEW: the radio only offers these two options.
        AgGrid(amazon_jp_reviews, gridOptions=grid_options, theme='material',
               enable_enterprise_modules=True,
               allow_unsafe_jscode=True)
        with open("test_reviews_jp.csv", "rb") as reviews_file:
            st.download_button(label="Download Additional Japanese Reviews", data=reviews_file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
                                      value=sample_japanese_review_input)
        # Whitespace-only input is as useless to the tokenizer/model as an empty string.
        if not jp_review_text.strip():
            st.error("Input text cannot be empty. Either write the Japanese review in text area manually "
                     "or select the review from the grid.")
            jp_review_text = None

if jp_review_text:
    _sub_header("Sentence-Piece based Japanese Tokenizer using RoBERTA")
    tokenizer = _load_tokenizer()
    tokens = tokenizer.tokenize(jp_review_text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    tokens_column, token_ids_column = st.columns(2)
    with tokens_column:
        with st.expander("Expand to see the tokens", expanded=False):
            st.write(tokens)
    with token_ids_column:
        with st.expander("Expand to see the token IDs", expanded=False):
            st.write(token_ids)

    _sub_header("Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor")
    # Batch of one review -> PyTorch tensors, truncated to at most 200 tokens.
    encoded_data = tokenizer([jp_review_text],
                             add_special_tokens=True,
                             return_attention_mask=True,
                             padding=True,
                             max_length=200,
                             return_tensors='pt',
                             truncation=True)
    input_ids = encoded_data['input_ids']
    attention_masks = encoded_data['attention_mask']
    input_ids_column, attention_masks_column = st.columns(2)
    with input_ids_column:
        with st.expander("Expand to see the input IDs tensor"):
            st.write(input_ids)
    with attention_masks_column:
        with st.expander("Expand to see the attention mask tensor"):
            st.write(attention_masks)

    _sub_header("Predict Sentiment of review using Fine-Tuned Japanese BERT")
    if st.button("Predict Sentiment"):
        with st.spinner("Wait.."):
            model = _load_sentiment_model()
            with torch.no_grad():
                logits = model(input_ids=input_ids, attention_mask=attention_masks).logits
            # One logit per class: softmax yields a distribution that sums to 1 (an
            # element-wise sigmoid does not). Assumes the model was fine-tuned as a
            # single-label classifier (cross-entropy) - TODO confirm against training.
            probabilities = torch.softmax(logits, dim=-1)[0].cpu().numpy()
        result = {"TEXT": jp_review_text,
                  'NEGATIVE': float(probabilities[label_dict['negative']]),
                  'POSITIVE': float(probabilities[label_dict['positive']])}
        result_col, graph_col = st.columns(2)
        with result_col:
            st.write(result)
        with graph_col:
            fig = px.bar(x=['NEGATIVE', 'POSITIVE'], y=[result['NEGATIVE'], result['POSITIVE']])
            fig.update_layout(title="Probability distribution of Sentiment for the given text",
                              yaxis_title="Probability")
            fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
            st.plotly_chart(fig)