shubh2014shiv committed on
Commit 18f8de6
1 Parent(s): 62cffe7

Upload app.py

Files changed (1)
  1. app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
from st_aggrid.shared import JsCode
from st_aggrid.shared import GridUpdateMode
from transformers import T5Tokenizer, BertForSequenceClassification
import torch
import numpy as np

st.set_page_config(layout="wide")
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
st.sidebar.subheader("自然言語処理 トピック")
topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis"])

st.write("-" * 5)
jp_review_text = None
# JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
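
# The Sentiment Analysis page below loads a sample of Japanese Amazon reviews, lets the
# user either pick a review from an interactive grid or type one manually, tokenizes the
# text with the SentencePiece tokenizer of rinna/japanese-roberta-base, and scores it with
# a fine-tuned BERT sequence-classification model.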
if topic == "Sentiment Analysis":
    st.markdown(
        "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiment Analysis using BERT</b></h2>",
        unsafe_allow_html=True)
    st.markdown(
        "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)</b></h3>",
        unsafe_allow_html=True)

    amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1, random_state=10).iloc[:16000]

    cellstyle_jscode = JsCode(
        """
        function(params) {
            if (params.value.includes('positive')) {
                return {
                    'color': 'black',
                    'backgroundColor': '#32CD32'
                }
            } else {
                return {
                    'color': 'black',
                    'backgroundColor': '#FF7F7F'
                }
            }
        };
        """
    )
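    # The JsCode above is applied to the "sentiment" column of the grid: positive rows are
    # rendered on a green (#32CD32) background and negative rows on a red (#FF7F7F) one.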
    st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
             unsafe_allow_html=True)

    st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
             unsafe_allow_html=True)

    SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
    WRITE_REVIEW = "Manually write review"

    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
    gb.configure_pagination()
    if choose == SELECT_ONE_REVIEW:
        gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    gridOptions = gb.build()
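
    # When the user opts to pick a review from the grid, single-row checkbox selection is
    # enabled and GridUpdateMode.SELECTION_CHANGED makes the app rerun on every selection,
    # so the chosen row becomes available in jp_review_choice['selected_rows'].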

    if choose == SELECT_ONE_REVIEW:
        jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
                                  enable_enterprise_modules=True,
                                  allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one of the Japanese reviews by clicking its checkbox. Reviews can be navigated from each page.")
        if len(jp_review_choice['selected_rows']) != 0:
            jp_review_text = jp_review_choice['selected_rows'][0]['review']
            st.markdown(
                "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)</b></h3>",
                unsafe_allow_html=True)
            st.write(jp_review_choice['selected_rows'])

    if choose == WRITE_REVIEW:
        AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
               enable_enterprise_modules=True,
               allow_unsafe_jscode=True)
        with open("test_reviews_jp.csv", "rb") as file:
            st.download_button(label="Download Additional Japanese Reviews", data=file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("An additional subset of Japanese reviews can be downloaded, and any review can be copied and pasted into the text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing the review in the text area below",
                                      value=sample_japanese_review_input)
        if len(jp_review_text) == 0:
            st.error("Input text cannot be empty. Either write a Japanese review in the text area manually or select a review from the grid.")

    if jp_review_text:
        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>SentencePiece-based Japanese Tokenizer using RoBERTa</b></h3>",
            unsafe_allow_html=True)
        tokens_column, tokenID_column = st.columns(2)
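        # rinna/japanese-roberta-base ships a SentencePiece vocabulary, which is why the
        # T5Tokenizer class (a SentencePiece-based tokenizer) is used here rather than a
        # BERT-specific tokenizer.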
        tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
        tokens = tokenizer.tokenize(jp_review_text)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        with tokens_column:
            token_expander = st.expander("Expand to see the tokens", expanded=False)
            with token_expander:
                st.write(tokens)
        with tokenID_column:
            tokenID_expander = st.expander("Expand to see the token IDs", expanded=False)
            with tokenID_expander:
                st.write(token_ids)

        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoding the Japanese Review Text into Input IDs and Attention Masks as PyTorch Tensors</b></h3>",
            unsafe_allow_html=True)
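        # batch_encode_plus returns both the padded/truncated token IDs ('input_ids') and the
        # matching 'attention_mask'; reviews longer than 200 tokens are truncated.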
        encoded_data = tokenizer.batch_encode_plus(np.array([jp_review_text]).astype('object'),
                                                   add_special_tokens=True,
                                                   return_attention_mask=True,
                                                   padding=True,
                                                   max_length=200,
                                                   return_tensors='pt',
                                                   truncation=True)
        input_ids = encoded_data['input_ids']
        attention_masks = encoded_data['attention_mask']
        input_ids_column, attention_masks_column = st.columns(2)
        with input_ids_column:
            input_ids_expander = st.expander("Expand to see the input IDs tensor")
            with input_ids_expander:
                st.write(input_ids)
        with attention_masks_column:
            attention_masks_expander = st.expander("Expand to see the attention mask tensor")
            with attention_masks_expander:
                st.write(attention_masks)

        st.markdown(
            "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predicting the Sentiment of the Review using Fine-Tuned Japanese BERT</b></h3>",
            unsafe_allow_html=True)

        label_dict = {'positive': 1, 'negative': 0}
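        # The fine-tuned checkpoint "shubh2014shiv/jp_review_sentiments_amzn" is pulled from
        # the Hugging Face Hub; the commented-out load_state_dict call below shows the
        # original local-weights path that it replaces.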
        if st.button("Predict Sentiment"):
            with st.spinner("Please wait..."):
                model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
                                                                      num_labels=len(label_dict),
                                                                      output_attentions=False,
                                                                      output_hidden_states=False)
                # model.load_state_dict(
                #     torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
                #                map_location=torch.device('cpu')))

                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_masks
                }

                with torch.no_grad():
                    outputs = model(**inputs)

                logits = outputs.logits
                logits = logits.detach().cpu().numpy()
                scores = 1 / (1 + np.exp(-1 * logits))
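                # Note: the sigmoid above is applied to each logit independently, so the two
                # scores do not necessarily sum to 1; they are per-class confidence values
                # rather than a softmax probability distribution.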

                result = {"TEXT": jp_review_text, 'NEGATIVE': scores[0][0], 'POSITIVE': scores[0][1]}

                result_col, graph_col = st.columns(2)
                with result_col:
                    st.write(result)
                with graph_col:
                    fig = px.bar(x=['NEGATIVE', 'POSITIVE'], y=[result['NEGATIVE'], result['POSITIVE']])
                    fig.update_layout(title="Probability distribution of Sentiment for the given text",
                                      yaxis_title="Probability")
                    fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
                    st.plotly_chart(fig)
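
# A minimal sketch (not part of the Streamlit UI) of how the same checkpoint could be scored
# outside the app; the review string is only an illustrative placeholder:
#
#     tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
#     model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn", num_labels=2)
#     enc = tokenizer("とても良い商品でした。", return_tensors="pt", truncation=True, max_length=200)
#     with torch.no_grad():
#         logits = model(**enc).logits
#     scores = torch.sigmoid(logits)[0]  # index 0 = negative, index 1 = positive (per label_dict above)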