Spaces:
Runtime error
Runtime error
shubh2014shiv
commited on
Commit
•
a40678b
1
Parent(s):
18556fd
Update app.py
Browse filesAdded comment while creating new folder for downloading the Japanese English translation model
app.py
CHANGED
@@ -1,361 +1,361 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import pandas as pd
|
3 |
-
import plotly.express as px
|
4 |
-
import plotly.graph_objects as go
|
5 |
-
from st_aggrid import AgGrid
|
6 |
-
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
7 |
-
from st_aggrid.shared import JsCode
|
8 |
-
from st_aggrid.shared import GridUpdateMode
|
9 |
-
from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
|
10 |
-
import torch
|
11 |
-
import numpy as np
|
12 |
-
import json
|
13 |
-
from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead
|
14 |
-
import pytorch_lightning as pl
|
15 |
-
from pathlib import Path
|
16 |
-
|
17 |
-
# Defining some functions for caching purpose by streamlit
|
18 |
-
class TranslationModel(pl.LightningModule):
|
19 |
-
def __init__(self):
|
20 |
-
super().__init__()
|
21 |
-
self.model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-ja-en", return_dict=True)
|
22 |
-
|
23 |
-
|
24 |
-
@st.experimental_singleton
|
25 |
-
def loadFineTunedJaEn_NMT_Model():
|
26 |
-
save_dest = Path('model')
|
27 |
-
save_dest.mkdir(exist_ok=True)
|
28 |
-
|
29 |
-
f_checkpoint = Path("model/best-checkpoint.ckpt")
|
30 |
-
|
31 |
-
if not f_checkpoint.exists():
|
32 |
-
with st.spinner("Downloading model.This may take a while! \n Don't refresh or close this page!"):
|
33 |
-
from GD_download import download_file_from_google_drive
|
34 |
-
download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', f_checkpoint)
|
35 |
-
|
36 |
-
trained_model = TranslationModel.load_from_checkpoint(f_checkpoint)
|
37 |
-
|
38 |
-
return trained_model
|
39 |
-
|
40 |
-
@st.experimental_singleton
|
41 |
-
def getJpEn_Tokenizers():
|
42 |
-
try:
|
43 |
-
with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
|
44 |
-
ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
|
45 |
-
en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
46 |
-
except:
|
47 |
-
st.error("Issue with downloading tokenizers")
|
48 |
-
|
49 |
-
return ja_tokenizer, en_tokenizer
|
50 |
-
|
51 |
-
st.set_page_config(layout="wide")
|
52 |
-
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
|
53 |
-
st.sidebar.subheader("自然言語処理 トピック")
|
54 |
-
topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization","Japanese to English Translation"])
|
55 |
-
|
56 |
-
st.write("-" * 5)
|
57 |
-
jp_review_text = None
|
58 |
-
#JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
|
59 |
-
|
60 |
-
if topic == "Sentiment Analysis":
|
61 |
-
st.markdown(
|
62 |
-
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiments Analysis using BERT<b></h2>",
|
63 |
-
unsafe_allow_html=True)
|
64 |
-
st.markdown(
|
65 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)<b></h3>",
|
66 |
-
unsafe_allow_html=True)
|
67 |
-
|
68 |
-
amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1,random_state=10).iloc[:16000]
|
69 |
-
|
70 |
-
cellstyle_jscode = JsCode(
|
71 |
-
"""
|
72 |
-
function(params) {
|
73 |
-
if (params.value.includes('positive')) {
|
74 |
-
return {
|
75 |
-
'color': 'black',
|
76 |
-
'backgroundColor': '#32CD32'
|
77 |
-
}
|
78 |
-
} else {
|
79 |
-
return {
|
80 |
-
'color': 'black',
|
81 |
-
'backgroundColor': '#FF7F7F'
|
82 |
-
}
|
83 |
-
}
|
84 |
-
};
|
85 |
-
"""
|
86 |
-
)
|
87 |
-
st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
|
88 |
-
unsafe_allow_html=True)
|
89 |
-
|
90 |
-
st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
|
91 |
-
unsafe_allow_html=True)
|
92 |
-
|
93 |
-
choose = st.radio("", ("Choose a review from the dataframe below", "Manually write review"))
|
94 |
-
|
95 |
-
SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
|
96 |
-
WRITE_REVIEW = "Manually write review"
|
97 |
-
|
98 |
-
gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
|
99 |
-
gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
|
100 |
-
gb.configure_pagination()
|
101 |
-
if choose == SELECT_ONE_REVIEW:
|
102 |
-
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
103 |
-
gridOptions = gb.build()
|
104 |
-
|
105 |
-
if choose == SELECT_ONE_REVIEW:
|
106 |
-
jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
|
107 |
-
enable_enterprise_modules=True,
|
108 |
-
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
109 |
-
st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
|
110 |
-
if len(jp_review_choice['selected_rows']) != 0:
|
111 |
-
jp_review_text = jp_review_choice['selected_rows'][0]['review']
|
112 |
-
st.markdown(
|
113 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)<b></h3>",
|
114 |
-
unsafe_allow_html=True)
|
115 |
-
st.write(jp_review_choice['selected_rows'])
|
116 |
-
|
117 |
-
if choose == WRITE_REVIEW:
|
118 |
-
|
119 |
-
AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
|
120 |
-
enable_enterprise_modules=True,
|
121 |
-
allow_unsafe_jscode=True)
|
122 |
-
with open("test_reviews_jp.csv", "rb") as file:
|
123 |
-
st.download_button(label="Download Additional Japanese Reviews", data=file,
|
124 |
-
file_name="Additional Japanese Reviews.csv")
|
125 |
-
st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
|
126 |
-
sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
|
127 |
-
jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
|
128 |
-
value=sample_japanese_review_input)
|
129 |
-
if len(jp_review_text) == 0:
|
130 |
-
st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")
|
131 |
-
|
132 |
-
if jp_review_text:
|
133 |
-
st.markdown(
|
134 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Sentence-Piece based Japanese Tokenizer using RoBERTA<b></h3>",
|
135 |
-
unsafe_allow_html=True)
|
136 |
-
tokens_column, tokenID_column = st.columns(2)
|
137 |
-
tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
|
138 |
-
tokens = tokenizer.tokenize(jp_review_text)
|
139 |
-
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
140 |
-
with tokens_column:
|
141 |
-
token_expander = st.expander("Expand to see the tokens", expanded=False)
|
142 |
-
with token_expander:
|
143 |
-
st.write(tokens)
|
144 |
-
with tokenID_column:
|
145 |
-
tokenID_expander = st.expander("Expand to see the token IDs", expanded=False)
|
146 |
-
with tokenID_expander:
|
147 |
-
st.write(token_ids)
|
148 |
-
|
149 |
-
st.markdown(
|
150 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor<b></h3>",
|
151 |
-
unsafe_allow_html=True)
|
152 |
-
encoded_data = tokenizer.batch_encode_plus(np.array([jp_review_text]).astype('object'),
|
153 |
-
add_special_tokens=True,
|
154 |
-
return_attention_mask=True,
|
155 |
-
padding=True,
|
156 |
-
max_length=200,
|
157 |
-
return_tensors='pt',
|
158 |
-
truncation=True)
|
159 |
-
input_ids = encoded_data['input_ids']
|
160 |
-
attention_masks = encoded_data['attention_mask']
|
161 |
-
input_ids_column, attention_masks_column = st.columns(2)
|
162 |
-
with input_ids_column:
|
163 |
-
input_ids_expander = st.expander("Expand to see the input IDs tensor")
|
164 |
-
with input_ids_expander:
|
165 |
-
st.write(input_ids)
|
166 |
-
with attention_masks_column:
|
167 |
-
attention_masks_expander = st.expander("Expand to see the attention mask tensor")
|
168 |
-
with attention_masks_expander:
|
169 |
-
st.write(attention_masks)
|
170 |
-
|
171 |
-
st.markdown(
|
172 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predict Sentiment of review using Fine-Tuned Japanese BERT<b></h3>",
|
173 |
-
unsafe_allow_html=True)
|
174 |
-
|
175 |
-
label_dict = {'positive': 1, 'negative': 0}
|
176 |
-
if st.button("Predict Sentiment"):
|
177 |
-
with st.spinner("Wait.."):
|
178 |
-
predictions = []
|
179 |
-
model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
|
180 |
-
num_labels=len(label_dict),
|
181 |
-
output_attentions=False,
|
182 |
-
output_hidden_states=False)
|
183 |
-
#model.load_state_dict(
|
184 |
-
# torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
|
185 |
-
# map_location=torch.device('cpu')))
|
186 |
-
|
187 |
-
model.load_state_dict(
|
188 |
-
torch.load('reviewSentiments_jp.pt',
|
189 |
-
map_location=torch.device('cpu')))
|
190 |
-
|
191 |
-
inputs = {
|
192 |
-
'input_ids': input_ids,
|
193 |
-
'attention_mask': attention_masks
|
194 |
-
}
|
195 |
-
|
196 |
-
with torch.no_grad():
|
197 |
-
outputs = model(**inputs)
|
198 |
-
|
199 |
-
logits = outputs.logits
|
200 |
-
logits = logits.detach().cpu().numpy()
|
201 |
-
scores = 1 / (1 + np.exp(-1 * logits))
|
202 |
-
|
203 |
-
result = {"TEXT (文章)": jp_review_text,'NEGATIVE (ネガティブ)': scores[0][0], 'POSITIVE (ポジティブ)': scores[0][1]}
|
204 |
-
|
205 |
-
result_col,graph_col = st.columns(2)
|
206 |
-
with result_col:
|
207 |
-
st.write(result)
|
208 |
-
with graph_col:
|
209 |
-
fig = px.bar(x=['NEGATIVE (ネガティブ)','POSITIVE (ポジティブ)'],y=[result['NEGATIVE (ネガティブ)'],result['POSITIVE (ポジティブ)']])
|
210 |
-
fig.update_layout(title="Probability distribution of Sentiment for the given text",\
|
211 |
-
yaxis_title="Probability (確率)")
|
212 |
-
fig.update_traces(marker_color=['#FF7F7F','#32CD32'])
|
213 |
-
st.plotly_chart(fig)
|
214 |
-
|
215 |
-
elif topic == "Text Summarization":
|
216 |
-
st.markdown(
|
217 |
-
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Summarizing Japanese News Article using multi-Lingual T5 (mT5)<b></h2>",
|
218 |
-
unsafe_allow_html=True)
|
219 |
-
st.markdown(
|
220 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese News Article Data<b></h3>",
|
221 |
-
unsafe_allow_html=True)
|
222 |
-
|
223 |
-
news_articles = pd.read_csv("jp_news_articles_val.csv").sample(frac=0.75,
|
224 |
-
random_state=42)
|
225 |
-
gb = GridOptionsBuilder.from_dataframe(news_articles)
|
226 |
-
gb.configure_pagination()
|
227 |
-
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
228 |
-
gridOptions = gb.build()
|
229 |
-
jp_article = AgGrid(news_articles, gridOptions=gridOptions, theme='material',
|
230 |
-
enable_enterprise_modules=True,
|
231 |
-
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
232 |
-
|
233 |
-
# WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
|
234 |
-
if len(jp_article['selected_rows']) == 0:
|
235 |
-
st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
|
236 |
-
else:
|
237 |
-
article_text = jp_article['selected_rows'][0]['News Articles']
|
238 |
-
|
239 |
-
text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
|
240 |
-
summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )", min_value=120,max_value=160,step=5)
|
241 |
-
|
242 |
-
if text and st.button("Summarize it! (要約しよう)"):
|
243 |
-
waitPlaceholder = st.image("wait.gif")
|
244 |
-
summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
|
245 |
-
tokenizer = AutoTokenizer.from_pretrained(summarization_model_name )
|
246 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name )
|
247 |
-
|
248 |
-
input_ids = tokenizer(
|
249 |
-
article_text,
|
250 |
-
return_tensors="pt",
|
251 |
-
padding="max_length",
|
252 |
-
truncation=True,
|
253 |
-
max_length=512
|
254 |
-
)["input_ids"]
|
255 |
-
|
256 |
-
output_ids = model.generate(
|
257 |
-
input_ids=input_ids,
|
258 |
-
max_length=summary_length,
|
259 |
-
no_repeat_ngram_size=2,
|
260 |
-
num_beams=4
|
261 |
-
)[0]
|
262 |
-
|
263 |
-
summary = tokenizer.decode(
|
264 |
-
output_ids,
|
265 |
-
skip_special_tokens=True,
|
266 |
-
clean_up_tokenization_spaces=False
|
267 |
-
)
|
268 |
-
|
269 |
-
waitPlaceholder.empty()
|
270 |
-
|
271 |
-
st.markdown(
|
272 |
-
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Summary (要約文)<b></h2>",
|
273 |
-
unsafe_allow_html=True)
|
274 |
-
|
275 |
-
st.write(summary)
|
276 |
-
elif topic == "Japanese to English Translation":
|
277 |
-
st.markdown(
|
278 |
-
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Japanese to English translation (for short sentences)<b></h2>",
|
279 |
-
unsafe_allow_html=True)
|
280 |
-
st.markdown(
|
281 |
-
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Business Scene Dialog Japanese-English Corpus<b></h3>",
|
282 |
-
unsafe_allow_html=True)
|
283 |
-
|
284 |
-
st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
|
285 |
-
link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
|
286 |
-
st.markdown(link, unsafe_allow_html=True)
|
287 |
-
|
288 |
-
bsd_more_info = st.expander(label="Expand to get more information on data and training report")
|
289 |
-
with bsd_more_info:
|
290 |
-
st.markdown(
|
291 |
-
"<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Dataset<b></h3>",
|
292 |
-
unsafe_allow_html=True)
|
293 |
-
st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
|
294 |
-
|
295 |
-
st.markdown(
|
296 |
-
"<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Report<b></h3>",
|
297 |
-
unsafe_allow_html=True)
|
298 |
-
st.write(
|
299 |
-
"The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")
|
300 |
-
|
301 |
-
with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
|
302 |
-
bsd_sample_data = json.load(f)
|
303 |
-
|
304 |
-
en, ja = [], []
|
305 |
-
for i in range(len(bsd_sample_data)):
|
306 |
-
for j in range(len(bsd_sample_data[i]['conversation'])):
|
307 |
-
en.append(bsd_sample_data[i]['conversation'][j]['en_sentence'])
|
308 |
-
ja.append(bsd_sample_data[i]['conversation'][j]['ja_sentence'])
|
309 |
-
|
310 |
-
df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})
|
311 |
-
gb = GridOptionsBuilder.from_dataframe(df)
|
312 |
-
gb.configure_pagination()
|
313 |
-
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
314 |
-
gridOptions = gb.build()
|
315 |
-
translation_text = AgGrid(df, gridOptions=gridOptions, theme='material',
|
316 |
-
enable_enterprise_modules=True,
|
317 |
-
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
318 |
-
if len(translation_text['selected_rows']) != 0:
|
319 |
-
bsd_jp = translation_text['selected_rows'][0]['Japanese']
|
320 |
-
st.markdown(
|
321 |
-
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)<b></h2>",
|
322 |
-
unsafe_allow_html=True)
|
323 |
-
st.write(bsd_jp)
|
324 |
-
|
325 |
-
if st.button("Translate"):
|
326 |
-
ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
|
327 |
-
trained_model = loadFineTunedJaEn_NMT_Model()
|
328 |
-
trained_model.freeze()
|
329 |
-
|
330 |
-
|
331 |
-
def translate(text):
|
332 |
-
text_encoding = ja_tokenizer(
|
333 |
-
text,
|
334 |
-
max_length=100,
|
335 |
-
padding="max_length",
|
336 |
-
truncation=True,
|
337 |
-
return_attention_mask=True,
|
338 |
-
add_special_tokens=True,
|
339 |
-
return_tensors='pt'
|
340 |
-
)
|
341 |
-
|
342 |
-
generated_ids = trained_model.model.generate(
|
343 |
-
input_ids=text_encoding['input_ids'],
|
344 |
-
attention_mask=text_encoding['attention_mask'],
|
345 |
-
max_length=100,
|
346 |
-
num_beams=2,
|
347 |
-
repetition_penalty=2.5,
|
348 |
-
length_penalty=1.0,
|
349 |
-
early_stopping=True
|
350 |
-
)
|
351 |
-
|
352 |
-
preds = [en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True) for
|
353 |
-
gen_id in generated_ids]
|
354 |
-
|
355 |
-
return "".join(preds)[5:]
|
356 |
-
|
357 |
-
|
358 |
-
st.markdown(
|
359 |
-
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Translated Dialog in English (英語の翻訳されたダイアログ)<b></h2>",
|
360 |
-
unsafe_allow_html=True)
|
361 |
-
st.write(translate(bsd_jp))
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.express as px
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from st_aggrid import AgGrid
|
6 |
+
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
7 |
+
from st_aggrid.shared import JsCode
|
8 |
+
from st_aggrid.shared import GridUpdateMode
|
9 |
+
from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead
|
14 |
+
import pytorch_lightning as pl
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
# Defining some functions for caching purpose by streamlit
|
18 |
+
class TranslationModel(pl.LightningModule):
|
19 |
+
def __init__(self):
|
20 |
+
super().__init__()
|
21 |
+
self.model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-ja-en", return_dict=True)
|
22 |
+
|
23 |
+
|
24 |
+
@st.experimental_singleton
|
25 |
+
def loadFineTunedJaEn_NMT_Model():
|
26 |
+
save_dest = Path('model')
|
27 |
+
save_dest.mkdir(exist_ok=True)
|
28 |
+
st.write("Creating new folder for downloading the Japanese to English Translation Model. ")
|
29 |
+
f_checkpoint = Path("model/best-checkpoint.ckpt")
|
30 |
+
st.write("'Folder: model/best-checkpoint.ckpt' created.")
|
31 |
+
if not f_checkpoint.exists():
|
32 |
+
with st.spinner("Downloading model.This may take a while! \n Don't refresh or close this page!"):
|
33 |
+
from GD_download import download_file_from_google_drive
|
34 |
+
download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', f_checkpoint)
|
35 |
+
|
36 |
+
trained_model = TranslationModel.load_from_checkpoint(f_checkpoint)
|
37 |
+
|
38 |
+
return trained_model
|
39 |
+
|
40 |
+
@st.experimental_singleton
|
41 |
+
def getJpEn_Tokenizers():
|
42 |
+
try:
|
43 |
+
with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
|
44 |
+
ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
|
45 |
+
en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
46 |
+
except:
|
47 |
+
st.error("Issue with downloading tokenizers")
|
48 |
+
|
49 |
+
return ja_tokenizer, en_tokenizer
|
50 |
+
|
51 |
+
st.set_page_config(layout="wide")
|
52 |
+
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
|
53 |
+
st.sidebar.subheader("自然言語処理 トピック")
|
54 |
+
topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization","Japanese to English Translation"])
|
55 |
+
|
56 |
+
st.write("-" * 5)
|
57 |
+
jp_review_text = None
|
58 |
+
#JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
|
59 |
+
|
60 |
+
if topic == "Sentiment Analysis":
|
61 |
+
st.markdown(
|
62 |
+
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiments Analysis using BERT<b></h2>",
|
63 |
+
unsafe_allow_html=True)
|
64 |
+
st.markdown(
|
65 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)<b></h3>",
|
66 |
+
unsafe_allow_html=True)
|
67 |
+
|
68 |
+
amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1,random_state=10).iloc[:16000]
|
69 |
+
|
70 |
+
cellstyle_jscode = JsCode(
|
71 |
+
"""
|
72 |
+
function(params) {
|
73 |
+
if (params.value.includes('positive')) {
|
74 |
+
return {
|
75 |
+
'color': 'black',
|
76 |
+
'backgroundColor': '#32CD32'
|
77 |
+
}
|
78 |
+
} else {
|
79 |
+
return {
|
80 |
+
'color': 'black',
|
81 |
+
'backgroundColor': '#FF7F7F'
|
82 |
+
}
|
83 |
+
}
|
84 |
+
};
|
85 |
+
"""
|
86 |
+
)
|
87 |
+
st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
|
88 |
+
unsafe_allow_html=True)
|
89 |
+
|
90 |
+
st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
|
91 |
+
unsafe_allow_html=True)
|
92 |
+
|
93 |
+
choose = st.radio("", ("Choose a review from the dataframe below", "Manually write review"))
|
94 |
+
|
95 |
+
SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
|
96 |
+
WRITE_REVIEW = "Manually write review"
|
97 |
+
|
98 |
+
gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
|
99 |
+
gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
|
100 |
+
gb.configure_pagination()
|
101 |
+
if choose == SELECT_ONE_REVIEW:
|
102 |
+
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
103 |
+
gridOptions = gb.build()
|
104 |
+
|
105 |
+
if choose == SELECT_ONE_REVIEW:
|
106 |
+
jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
|
107 |
+
enable_enterprise_modules=True,
|
108 |
+
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
109 |
+
st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
|
110 |
+
if len(jp_review_choice['selected_rows']) != 0:
|
111 |
+
jp_review_text = jp_review_choice['selected_rows'][0]['review']
|
112 |
+
st.markdown(
|
113 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)<b></h3>",
|
114 |
+
unsafe_allow_html=True)
|
115 |
+
st.write(jp_review_choice['selected_rows'])
|
116 |
+
|
117 |
+
if choose == WRITE_REVIEW:
|
118 |
+
|
119 |
+
AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
|
120 |
+
enable_enterprise_modules=True,
|
121 |
+
allow_unsafe_jscode=True)
|
122 |
+
with open("test_reviews_jp.csv", "rb") as file:
|
123 |
+
st.download_button(label="Download Additional Japanese Reviews", data=file,
|
124 |
+
file_name="Additional Japanese Reviews.csv")
|
125 |
+
st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
|
126 |
+
sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
|
127 |
+
jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
|
128 |
+
value=sample_japanese_review_input)
|
129 |
+
if len(jp_review_text) == 0:
|
130 |
+
st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")
|
131 |
+
|
132 |
+
if jp_review_text:
|
133 |
+
st.markdown(
|
134 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Sentence-Piece based Japanese Tokenizer using RoBERTA<b></h3>",
|
135 |
+
unsafe_allow_html=True)
|
136 |
+
tokens_column, tokenID_column = st.columns(2)
|
137 |
+
tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
|
138 |
+
tokens = tokenizer.tokenize(jp_review_text)
|
139 |
+
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
140 |
+
with tokens_column:
|
141 |
+
token_expander = st.expander("Expand to see the tokens", expanded=False)
|
142 |
+
with token_expander:
|
143 |
+
st.write(tokens)
|
144 |
+
with tokenID_column:
|
145 |
+
tokenID_expander = st.expander("Expand to see the token IDs", expanded=False)
|
146 |
+
with tokenID_expander:
|
147 |
+
st.write(token_ids)
|
148 |
+
|
149 |
+
st.markdown(
|
150 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor<b></h3>",
|
151 |
+
unsafe_allow_html=True)
|
152 |
+
encoded_data = tokenizer.batch_encode_plus(np.array([jp_review_text]).astype('object'),
|
153 |
+
add_special_tokens=True,
|
154 |
+
return_attention_mask=True,
|
155 |
+
padding=True,
|
156 |
+
max_length=200,
|
157 |
+
return_tensors='pt',
|
158 |
+
truncation=True)
|
159 |
+
input_ids = encoded_data['input_ids']
|
160 |
+
attention_masks = encoded_data['attention_mask']
|
161 |
+
input_ids_column, attention_masks_column = st.columns(2)
|
162 |
+
with input_ids_column:
|
163 |
+
input_ids_expander = st.expander("Expand to see the input IDs tensor")
|
164 |
+
with input_ids_expander:
|
165 |
+
st.write(input_ids)
|
166 |
+
with attention_masks_column:
|
167 |
+
attention_masks_expander = st.expander("Expand to see the attention mask tensor")
|
168 |
+
with attention_masks_expander:
|
169 |
+
st.write(attention_masks)
|
170 |
+
|
171 |
+
st.markdown(
|
172 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predict Sentiment of review using Fine-Tuned Japanese BERT<b></h3>",
|
173 |
+
unsafe_allow_html=True)
|
174 |
+
|
175 |
+
label_dict = {'positive': 1, 'negative': 0}
|
176 |
+
if st.button("Predict Sentiment"):
|
177 |
+
with st.spinner("Wait.."):
|
178 |
+
predictions = []
|
179 |
+
model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
|
180 |
+
num_labels=len(label_dict),
|
181 |
+
output_attentions=False,
|
182 |
+
output_hidden_states=False)
|
183 |
+
#model.load_state_dict(
|
184 |
+
# torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
|
185 |
+
# map_location=torch.device('cpu')))
|
186 |
+
|
187 |
+
model.load_state_dict(
|
188 |
+
torch.load('reviewSentiments_jp.pt',
|
189 |
+
map_location=torch.device('cpu')))
|
190 |
+
|
191 |
+
inputs = {
|
192 |
+
'input_ids': input_ids,
|
193 |
+
'attention_mask': attention_masks
|
194 |
+
}
|
195 |
+
|
196 |
+
with torch.no_grad():
|
197 |
+
outputs = model(**inputs)
|
198 |
+
|
199 |
+
logits = outputs.logits
|
200 |
+
logits = logits.detach().cpu().numpy()
|
201 |
+
scores = 1 / (1 + np.exp(-1 * logits))
|
202 |
+
|
203 |
+
result = {"TEXT (文章)": jp_review_text,'NEGATIVE (ネガティブ)': scores[0][0], 'POSITIVE (ポジティブ)': scores[0][1]}
|
204 |
+
|
205 |
+
result_col,graph_col = st.columns(2)
|
206 |
+
with result_col:
|
207 |
+
st.write(result)
|
208 |
+
with graph_col:
|
209 |
+
fig = px.bar(x=['NEGATIVE (ネガティブ)','POSITIVE (ポジティブ)'],y=[result['NEGATIVE (ネガティブ)'],result['POSITIVE (ポジティブ)']])
|
210 |
+
fig.update_layout(title="Probability distribution of Sentiment for the given text",\
|
211 |
+
yaxis_title="Probability (確率)")
|
212 |
+
fig.update_traces(marker_color=['#FF7F7F','#32CD32'])
|
213 |
+
st.plotly_chart(fig)
|
214 |
+
|
215 |
+
elif topic == "Text Summarization":
|
216 |
+
st.markdown(
|
217 |
+
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Summarizing Japanese News Article using multi-Lingual T5 (mT5)<b></h2>",
|
218 |
+
unsafe_allow_html=True)
|
219 |
+
st.markdown(
|
220 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese News Article Data<b></h3>",
|
221 |
+
unsafe_allow_html=True)
|
222 |
+
|
223 |
+
news_articles = pd.read_csv("jp_news_articles_val.csv").sample(frac=0.75,
|
224 |
+
random_state=42)
|
225 |
+
gb = GridOptionsBuilder.from_dataframe(news_articles)
|
226 |
+
gb.configure_pagination()
|
227 |
+
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
228 |
+
gridOptions = gb.build()
|
229 |
+
jp_article = AgGrid(news_articles, gridOptions=gridOptions, theme='material',
|
230 |
+
enable_enterprise_modules=True,
|
231 |
+
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
232 |
+
|
233 |
+
# WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
|
234 |
+
if len(jp_article['selected_rows']) == 0:
|
235 |
+
st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
|
236 |
+
else:
|
237 |
+
article_text = jp_article['selected_rows'][0]['News Articles']
|
238 |
+
|
239 |
+
text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
|
240 |
+
summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )", min_value=120,max_value=160,step=5)
|
241 |
+
|
242 |
+
if text and st.button("Summarize it! (要約しよう)"):
|
243 |
+
waitPlaceholder = st.image("wait.gif")
|
244 |
+
summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
|
245 |
+
tokenizer = AutoTokenizer.from_pretrained(summarization_model_name )
|
246 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name )
|
247 |
+
|
248 |
+
input_ids = tokenizer(
|
249 |
+
article_text,
|
250 |
+
return_tensors="pt",
|
251 |
+
padding="max_length",
|
252 |
+
truncation=True,
|
253 |
+
max_length=512
|
254 |
+
)["input_ids"]
|
255 |
+
|
256 |
+
output_ids = model.generate(
|
257 |
+
input_ids=input_ids,
|
258 |
+
max_length=summary_length,
|
259 |
+
no_repeat_ngram_size=2,
|
260 |
+
num_beams=4
|
261 |
+
)[0]
|
262 |
+
|
263 |
+
summary = tokenizer.decode(
|
264 |
+
output_ids,
|
265 |
+
skip_special_tokens=True,
|
266 |
+
clean_up_tokenization_spaces=False
|
267 |
+
)
|
268 |
+
|
269 |
+
waitPlaceholder.empty()
|
270 |
+
|
271 |
+
st.markdown(
|
272 |
+
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Summary (要約文)<b></h2>",
|
273 |
+
unsafe_allow_html=True)
|
274 |
+
|
275 |
+
st.write(summary)
|
276 |
+
elif topic == "Japanese to English Translation":
|
277 |
+
st.markdown(
|
278 |
+
"<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Japanese to English translation (for short sentences)<b></h2>",
|
279 |
+
unsafe_allow_html=True)
|
280 |
+
st.markdown(
|
281 |
+
"<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Business Scene Dialog Japanese-English Corpus<b></h3>",
|
282 |
+
unsafe_allow_html=True)
|
283 |
+
|
284 |
+
st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
|
285 |
+
link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
|
286 |
+
st.markdown(link, unsafe_allow_html=True)
|
287 |
+
|
288 |
+
bsd_more_info = st.expander(label="Expand to get more information on data and training report")
|
289 |
+
with bsd_more_info:
|
290 |
+
st.markdown(
|
291 |
+
"<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Dataset<b></h3>",
|
292 |
+
unsafe_allow_html=True)
|
293 |
+
st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
|
294 |
+
|
295 |
+
st.markdown(
|
296 |
+
"<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Report<b></h3>",
|
297 |
+
unsafe_allow_html=True)
|
298 |
+
st.write(
|
299 |
+
"The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")
|
300 |
+
|
301 |
+
with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
|
302 |
+
bsd_sample_data = json.load(f)
|
303 |
+
|
304 |
+
en, ja = [], []
|
305 |
+
for i in range(len(bsd_sample_data)):
|
306 |
+
for j in range(len(bsd_sample_data[i]['conversation'])):
|
307 |
+
en.append(bsd_sample_data[i]['conversation'][j]['en_sentence'])
|
308 |
+
ja.append(bsd_sample_data[i]['conversation'][j]['ja_sentence'])
|
309 |
+
|
310 |
+
df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})
|
311 |
+
gb = GridOptionsBuilder.from_dataframe(df)
|
312 |
+
gb.configure_pagination()
|
313 |
+
gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
|
314 |
+
gridOptions = gb.build()
|
315 |
+
translation_text = AgGrid(df, gridOptions=gridOptions, theme='material',
|
316 |
+
enable_enterprise_modules=True,
|
317 |
+
allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
|
318 |
+
if len(translation_text['selected_rows']) != 0:
|
319 |
+
bsd_jp = translation_text['selected_rows'][0]['Japanese']
|
320 |
+
st.markdown(
|
321 |
+
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)<b></h2>",
|
322 |
+
unsafe_allow_html=True)
|
323 |
+
st.write(bsd_jp)
|
324 |
+
|
325 |
+
if st.button("Translate"):
|
326 |
+
ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
|
327 |
+
trained_model = loadFineTunedJaEn_NMT_Model()
|
328 |
+
trained_model.freeze()
|
329 |
+
|
330 |
+
|
331 |
+
def translate(text):
|
332 |
+
text_encoding = ja_tokenizer(
|
333 |
+
text,
|
334 |
+
max_length=100,
|
335 |
+
padding="max_length",
|
336 |
+
truncation=True,
|
337 |
+
return_attention_mask=True,
|
338 |
+
add_special_tokens=True,
|
339 |
+
return_tensors='pt'
|
340 |
+
)
|
341 |
+
|
342 |
+
generated_ids = trained_model.model.generate(
|
343 |
+
input_ids=text_encoding['input_ids'],
|
344 |
+
attention_mask=text_encoding['attention_mask'],
|
345 |
+
max_length=100,
|
346 |
+
num_beams=2,
|
347 |
+
repetition_penalty=2.5,
|
348 |
+
length_penalty=1.0,
|
349 |
+
early_stopping=True
|
350 |
+
)
|
351 |
+
|
352 |
+
preds = [en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True) for
|
353 |
+
gen_id in generated_ids]
|
354 |
+
|
355 |
+
return "".join(preds)[5:]
|
356 |
+
|
357 |
+
|
358 |
+
st.markdown(
|
359 |
+
"<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Translated Dialog in English (英語の翻訳されたダイアログ)<b></h2>",
|
360 |
+
unsafe_allow_html=True)
|
361 |
+
st.write(translate(bsd_jp))
|